From b7f9e431a5459c0cff51ba4ebd80e20235baeaef Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Wed, 22 Jan 2025 09:39:07 -0800 Subject: [PATCH] pydantic for LangGraph + changed ERT extraction flow --- .../onyx/agents/agent_search/core_state.py | 11 +- .../answer_initial_sub_question/edges.py | 7 +- .../graph_builder.py | 1 + .../nodes/answer_check.py | 6 +- .../nodes/answer_generation.py | 6 +- .../nodes/format_answer.py | 16 +-- .../nodes/ingest_retrieval.py | 12 +- .../answer_initial_sub_question/states.py | 29 ++--- .../answer_refinement_sub_question/edges.py | 8 +- .../graph_builder.py | 1 + .../nodes/format_raw_search_results.py | 4 +- .../nodes/generate_raw_search_data.py | 1 + .../deep_search_a/base_raw_search/states.py | 6 +- .../deep_search_a/expanded_retrieval/edges.py | 9 +- .../expanded_retrieval/graph_builder.py | 15 ++- .../expanded_retrieval/models.py | 6 +- .../expanded_retrieval/nodes/doc_reranking.py | 4 +- .../expanded_retrieval/nodes/doc_retrieval.py | 2 +- .../nodes/doc_verification.py | 4 +- .../expanded_retrieval/nodes/dummy.py | 16 +++ .../nodes/expand_queries.py | 8 +- .../nodes/format_results.py | 26 ++--- .../nodes/verification_kickoff.py | 11 +- .../expanded_retrieval/states.py | 41 +++---- .../agent_search/deep_search_a/main/edges.py | 12 +- .../deep_search_a/main/graph_builder.py | 38 ++++++- .../agent_search/deep_search_a/main/models.py | 14 +-- .../deep_search_a/main/nodes/agent_logging.py | 14 +-- .../main/nodes/agent_path_decision.py | 7 +- .../main/nodes/agent_path_routing.py | 2 +- .../main/nodes/entity_term_extraction_llm.py | 10 +- .../main/nodes/generate_initial_answer.py | 24 ++-- ...enerate_initial_base_search_only_answer.py | 2 +- .../main/nodes/generate_refined_answer.py | 42 +++---- .../nodes/ingest_initial_base_retrieval.py | 14 +-- .../ingest_initial_sub_question_answers.py | 2 +- .../main/nodes/ingest_refined_answers.py | 2 +- .../nodes/initial_sub_question_creation.py | 2 +- 
.../main/nodes/refined_answer_decision.py | 4 +- .../nodes/refined_sub_question_creation.py | 8 +- .../main/nodes/retrieval_consolidation.py | 12 ++ .../agent_search/deep_search_a/main/states.py | 104 ++++++++++-------- backend/onyx/agents/agent_search/models.py | 8 +- backend/onyx/agents/agent_search/run_graph.py | 13 ++- .../agent_search/shared_graph_utils/models.py | 34 +++--- 45 files changed, 364 insertions(+), 254 deletions(-) create mode 100644 backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/dummy.py create mode 100644 backend/onyx/agents/agent_search/deep_search_a/main/nodes/retrieval_consolidation.py diff --git a/backend/onyx/agents/agent_search/core_state.py b/backend/onyx/agents/agent_search/core_state.py index 693356b51..3fb00735d 100644 --- a/backend/onyx/agents/agent_search/core_state.py +++ b/backend/onyx/agents/agent_search/core_state.py @@ -1,18 +1,19 @@ from operator import add from typing import Annotated -from typing import TypedDict + +from pydantic import BaseModel -class CoreState(TypedDict, total=False): +class CoreState(BaseModel): """ This is the core state that is shared across all subgraphs. """ - base_question: str - log_messages: Annotated[list[str], add] + base_question: str = "" + log_messages: Annotated[list[str], add] = [] -class SubgraphCoreState(TypedDict, total=False): +class SubgraphCoreState(BaseModel): """ This is the core state that is shared across all subgraphs. 
""" diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py index a67e508f7..aa9ffafde 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py @@ -1,4 +1,5 @@ from collections.abc import Hashable +from datetime import datetime from langgraph.types import Send @@ -15,12 +16,14 @@ logger = setup_logger() def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: logger.debug("sending to expanded retrieval via edge") + now_start = datetime.now() return Send( "initial_sub_question_expanded_retrieval", ExpandedRetrievalInput( - question=state["question"], + question=state.question, base_search=False, - sub_question_id=state["question_id"], + sub_question_id=state.question_id, + log_messages=[f"{now_start} -- Sending to expanded retrieval"], ), ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py index 5af71933f..a6d9c6ecb 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py @@ -115,6 +115,7 @@ if __name__ == "__main__": inputs = AnswerQuestionInput( question="what can you do with onyx?", question_id="0_0", + log_messages=[], ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py index 6fe5e7a9e..5181bc504 100644 --- 
a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py @@ -17,15 +17,15 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate: - if state["answer"] == UNKNOWN_ANSWER: + if state.answer == UNKNOWN_ANSWER: return QACheckUpdate( answer_quality=SUB_CHECK_NO, ) msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( - question=state["question"], - base_answer=state["answer"], + question=state.question, + base_answer=state.answer, ) ) ] diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py index fa482f58a..b09696490 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py @@ -40,9 +40,9 @@ def answer_generation( logger.debug(f"--------{now_start}--------START ANSWER GENERATION---") agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) - question = state["question"] - docs = state["documents"] - level, question_nr = parse_question_id(state["question_id"]) + question = state.question + docs = state.documents + level, question_nr = parse_question_id(state.question_id) persona_prompt = get_persona_prompt(agent_search_config.search_request.persona) if len(docs) == 0: diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py index a82cf83bd..716167dc6 100644 --- 
a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py @@ -13,13 +13,15 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: return AnswerQuestionOutput( answer_results=[ QuestionAnswerResults( - question=state["question"], - question_id=state["question_id"], - quality=state.get("answer_quality", "No"), - answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - sub_question_retrieval_stats=state["sub_question_retrieval_stats"], + question=state.question, + question_id=state.question_id, + quality=state.answer_quality + if hasattr(state, "answer_quality") + else "No", + answer=state.answer, + expanded_retrieval_results=state.expanded_retrieval_results, + documents=state.documents, + sub_question_retrieval_stats=state.sub_question_retrieval_stats, ) ], ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py index 3b4c305c2..fb4c4ce73 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py @@ -8,16 +8,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: - sub_question_retrieval_stats = state[ - "expanded_retrieval_result" - ].sub_question_retrieval_stats + sub_question_retrieval_stats = ( + state.expanded_retrieval_result.sub_question_retrieval_stats + ) if sub_question_retrieval_stats is None: sub_question_retrieval_stats = [AgentChunkStats()] return RetrievalIngestionUpdate( - 
expanded_retrieval_results=state[ - "expanded_retrieval_result" - ].expanded_queries_results, - documents=state["expanded_retrieval_result"].all_documents, + expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results, + documents=state.expanded_retrieval_result.all_documents, sub_question_retrieval_stats=sub_question_retrieval_stats, ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py index 98f464dce..1f283d915 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py @@ -1,6 +1,7 @@ from operator import add from typing import Annotated -from typing import TypedDict + +from pydantic import BaseModel from onyx.agents.agent_search.core_state import SubgraphCoreState from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats @@ -15,27 +16,29 @@ from onyx.context.search.models import InferenceSection ## Update States -class QACheckUpdate(TypedDict): - answer_quality: str +class QACheckUpdate(BaseModel): + answer_quality: str = "" -class QAGenerationUpdate(TypedDict): - answer: str +class QAGenerationUpdate(BaseModel): + answer: str = "" # answer_stat: AnswerStats -class RetrievalIngestionUpdate(TypedDict): - expanded_retrieval_results: list[QueryResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: AgentChunkStats +class RetrievalIngestionUpdate(BaseModel): + expanded_retrieval_results: list[QueryResult] = [] + documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] + sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats() ## Graph Input State class AnswerQuestionInput(SubgraphCoreState): - question: str - question_id: str # 0_0 is original question, everything else is _. 
+ question: str = "" + question_id: str = ( + "" # 0_0 is original question, everything else is _. + ) # level 0 is original question and first decomposition, level 1 is follow up, etc # question_num is a unique number per original question per level. @@ -55,11 +58,11 @@ class AnswerQuestionState( ## Graph Output State -class AnswerQuestionOutput(TypedDict): +class AnswerQuestionOutput(BaseModel): """ This is a list of results even though each call of this subgraph only returns one result. This is because if we parallelize the answer query subgraph, there will be multiple results in a list so the add operator is used to add them together. """ - answer_results: Annotated[list[QuestionAnswerResults], add] + answer_results: Annotated[list[QuestionAnswerResults], add] = [] diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py index 2a5fdc148..201ac6a4a 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py @@ -1,4 +1,5 @@ from collections.abc import Hashable +from datetime import datetime from langgraph.types import Send @@ -15,12 +16,13 @@ logger = setup_logger() def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable: logger.debug("sending to expanded retrieval for follow up question via edge") - + datetime.now() return Send( "refined_sub_question_expanded_retrieval", ExpandedRetrievalInput( - question=state["question"], - sub_question_id=state["question_id"], + question=state.question, + sub_question_id=state.question_id, base_search=False, + log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"], ), ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py 
b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py index 3598c4dc6..fd991aef2 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py @@ -111,6 +111,7 @@ if __name__ == "__main__": inputs = AnswerQuestionInput( question="what can you do with onyx?", question_id="0_0", + log_messages=[], ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py index 527a01323..92816a93c 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py @@ -12,7 +12,7 @@ logger = setup_logger() def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput: logger.debug("format_raw_search_results") return BaseRawSearchOutput( - base_expanded_retrieval_result=state["expanded_retrieval_result"], - # base_retrieval_results=[state["expanded_retrieval_result"]], + base_expanded_retrieval_result=state.expanded_retrieval_result, + # base_retrieval_results=[state.expanded_retrieval_result], # base_search_documents=[], ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py index 22e69eee9..5526745b6 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py @@ -21,4 +21,5 @@ def generate_raw_search_data( 
question=agent_a_config.search_request.query, base_search=True, sub_question_id=None, # This graph is always and only used for the original question + log_messages=[], ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py index 90676e77e..ed45b7d4d 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py @@ -1,4 +1,4 @@ -from typing import TypedDict +from pydantic import BaseModel from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( ExpandedRetrievalResult, @@ -21,7 +21,7 @@ class BaseRawSearchInput(ExpandedRetrievalInput): ## Graph Output State -class BaseRawSearchOutput(TypedDict): +class BaseRawSearchOutput(BaseModel): """ This is a list of results even though each call of this subgraph only returns one result. This is because if we parallelize the answer query subgraph, there will be multiple @@ -30,7 +30,7 @@ class BaseRawSearchOutput(TypedDict): # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections] # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] - base_expanded_retrieval_result: ExpandedRetrievalResult + base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult() ## Graph State diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py index 6a2db1402..b76d12755 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py @@ -17,9 +17,11 @@ def parallel_retrieval_edge( state: ExpandedRetrievalState, config: RunnableConfig ) -> list[Send | Hashable]: agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) - question = 
state.get("question", agent_a_config.search_request.query) + question = state.question if state.question else agent_a_config.search_request.query - query_expansions = state.get("expanded_queries", []) + [question] + query_expansions = ( + state.expanded_queries if state.expanded_queries else [] + [question] + ) return [ Send( "doc_retrieval", @@ -27,7 +29,8 @@ def parallel_retrieval_edge( query_to_retrieve=query, question=question, base_search=False, - sub_question_id=state.get("sub_question_id"), + sub_question_id=state.sub_question_id, + log_messages=[], ), ) for query in query_expansions diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py index 1f0d88ece..2251bb6bd 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py @@ -14,6 +14,9 @@ from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retriev from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import ( doc_verification, ) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import ( + dummy, +) from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import ( expand_queries, ) @@ -52,6 +55,11 @@ def expanded_retrieval_graph_builder() -> StateGraph: action=expand_queries, ) + graph.add_node( + node="dummy", + action=dummy, + ) + graph.add_node( node="doc_retrieval", action=doc_retrieval, @@ -78,9 +86,13 @@ def expanded_retrieval_graph_builder() -> StateGraph: start_key=START, end_key="expand_queries", ) + graph.add_edge( + start_key="expand_queries", + end_key="dummy", + ) graph.add_conditional_edges( - source="expand_queries", + source="dummy", path=parallel_retrieval_edge, path_map=["doc_retrieval"], ) @@ -124,6 +136,7 @@ if __name__ == "__main__": 
question="what can you do with onyx?", base_search=False, sub_question_id=None, + log_messages=[], ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py index 139f3311a..f564bacab 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py @@ -6,6 +6,6 @@ from onyx.context.search.models import InferenceSection class ExpandedRetrievalResult(BaseModel): - expanded_queries_results: list[QueryResult] - all_documents: list[InferenceSection] - sub_question_retrieval_stats: AgentChunkStats + expanded_queries_results: list[QueryResult] = [] + all_documents: list[InferenceSection] = [] + sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats() diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py index 03a50af61..57ef69816 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py @@ -24,13 +24,13 @@ from onyx.db.engine import get_session_context_manager def doc_reranking( state: ExpandedRetrievalState, config: RunnableConfig ) -> DocRerankingUpdate: - verified_documents = state["verified_documents"] + verified_documents = state.verified_documents # Rerank post retrieval and verification. 
First, create a search query # then create the list of reranked sections agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) - question = state.get("question", agent_a_config.search_request.query) + question = state.question if state.question else agent_a_config.search_request.query with get_session_context_manager() as db_session: _search_query = retrieval_preprocessing( search_request=SearchRequest(query=question), diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py index 899f1d677..4820cb5b4 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py @@ -35,7 +35,7 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval expanded_retrieval_results: list[ExpandedRetrievalResult] retrieved_documents: list[InferenceSection] """ - query_to_retrieve = state["query_to_retrieve"] + query_to_retrieve = state.query_to_retrieve agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) search_tool = agent_a_config.search_tool diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py index fbd10b30d..e96f809da 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py @@ -30,8 +30,8 @@ def doc_verification( verified_documents: list[InferenceSection] """ - question = state["question"] - doc_to_verify = state["doc_to_verify"] + question = state.question + doc_to_verify = state.doc_to_verify document_content = doc_to_verify.combined_content agent_a_config = 
cast(AgentSearchConfig, config["metadata"]["config"]) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/dummy.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/dummy.py new file mode 100644 index 000000000..93ff3dd23 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/dummy.py @@ -0,0 +1,16 @@ +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + QueryExpansionUpdate, +) + + +def dummy( + state: ExpandedRetrievalState, config: RunnableConfig +) -> QueryExpansionUpdate: + return QueryExpansionUpdate( + expanded_queries=state.expanded_queries, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py index b0b168d4c..921b040e2 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py @@ -28,10 +28,14 @@ def expand_queries( # When we are running this node on the original question, no question is explictly passed in. # Instead, we use the original question from the search request. 
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) - question = state.get("question", agent_a_config.search_request.query) + question = ( + state.question + if hasattr(state, "question") + else agent_a_config.search_request.query + ) llm = agent_a_config.fast_llm chat_session_id = agent_a_config.chat_session_id - sub_question_id = state.get("sub_question_id") + sub_question_id = state.sub_question_id if sub_question_id is None: level, question_nr = 0, 0 else: diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py index b3ebf5de9..0401ba5a5 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py @@ -25,10 +25,10 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp def format_results( state: ExpandedRetrievalState, config: RunnableConfig ) -> ExpandedRetrievalUpdate: - level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0") + level, question_nr = parse_question_id(state.sub_question_id or "0_0") query_infos = [ result.query_info - for result in state["expanded_retrieval_results"] + for result in state.expanded_retrieval_results if result.query_info is not None ] if len(query_infos) == 0: @@ -37,19 +37,15 @@ def format_results( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) # main question docs will be sent later after aggregation and deduping with sub-question docs if not (level == 0 and question_nr == 0): - if len(state["reranked_documents"]) > 0: - stream_documents = state["reranked_documents"] + if len(state.reranked_documents) > 0: + stream_documents = state.reranked_documents else: # The sub-question is used as the last query. 
If no verified documents are found, stream # the top 3 for that one. We may want to revisit this. - stream_documents = state["expanded_retrieval_results"][-1].search_results[ - :3 - ] + stream_documents = state.expanded_retrieval_results[-1].search_results[:3] for tool_response in yield_search_responses( - query=state["question"], - reranked_sections=state[ - "retrieved_documents" - ], # TODO: rename params. this one is supposed to be the sections pre-merging + query=state.question, + reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.) final_context_sections=stream_documents, search_query_info=query_infos[0], # TODO: handle differing query infos? get_section_relevance=lambda: None, # TODO: add relevance @@ -65,8 +61,8 @@ def format_results( ), ) sub_question_retrieval_stats = calculate_sub_question_retrieval_stats( - verified_documents=state["verified_documents"], - expanded_retrieval_results=state["expanded_retrieval_results"], + verified_documents=state.verified_documents, + expanded_retrieval_results=state.expanded_retrieval_results, ) if sub_question_retrieval_stats is None: @@ -76,8 +72,8 @@ def format_results( return ExpandedRetrievalUpdate( expanded_retrieval_result=ExpandedRetrievalResult( - expanded_queries_results=state["expanded_retrieval_results"], - all_documents=state["reranked_documents"], + expanded_queries_results=state.expanded_retrieval_results, + all_documents=state.reranked_documents, sub_question_retrieval_stats=sub_question_retrieval_stats, ), ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py index 744242b72..55e813d12 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py @@ -18,10 +18,14 
@@ def verification_kickoff( state: ExpandedRetrievalState, config: RunnableConfig, ) -> Command[Literal["doc_verification"]]: - documents = state["retrieved_documents"] + documents = state.retrieved_documents agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) - verification_question = state.get("question", agent_a_config.search_request.query) - sub_question_id = state.get("sub_question_id") + verification_question = ( + state.question + if state.question + else agent_a_config.search_request.query + ) + sub_question_id = state.sub_question_id return Command( update={}, goto=[ @@ -32,6 +36,7 @@ def verification_kickoff( question=verification_question, base_search=False, sub_question_id=sub_question_id, + log_messages=[], ), ) for doc in documents diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py index 2572b8390..f459437f7 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py @@ -1,6 +1,7 @@ from operator import add from typing import Annotated -from typing import TypedDict + +from pydantic import BaseModel from onyx.agents.agent_search.core_state import SubgraphCoreState from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( @@ -20,42 +21,44 @@ from onyx.context.search.models import InferenceSection class ExpandedRetrievalInput(SubgraphCoreState): - question: str - base_search: bool - sub_question_id: str | None + question: str = "" + base_search: bool = False + sub_question_id: str | None = None ## Update/Return States -class QueryExpansionUpdate(TypedDict): - expanded_queries: list[str] +class QueryExpansionUpdate(BaseModel): + expanded_queries: list[str] = [] -class DocVerificationUpdate(TypedDict): - verified_documents: Annotated[list[InferenceSection],
dedup_inference_sections] +class DocVerificationUpdate(BaseModel): + verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] -class DocRetrievalUpdate(TypedDict): - expanded_retrieval_results: Annotated[list[QueryResult], add] - retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] +class DocRetrievalUpdate(BaseModel): + expanded_retrieval_results: Annotated[list[QueryResult], add] = [] + retrieved_documents: Annotated[ + list[InferenceSection], dedup_inference_sections + ] = [] -class DocRerankingUpdate(TypedDict): - reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: RetrievalFitStats | None +class DocRerankingUpdate(BaseModel): + reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] + sub_question_retrieval_stats: RetrievalFitStats | None = None -class ExpandedRetrievalUpdate(TypedDict): +class ExpandedRetrievalUpdate(BaseModel): expanded_retrieval_result: ExpandedRetrievalResult ## Graph Output State -class ExpandedRetrievalOutput(TypedDict): - expanded_retrieval_result: ExpandedRetrievalResult - base_expanded_retrieval_result: ExpandedRetrievalResult +class ExpandedRetrievalOutput(BaseModel): + expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult() + base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult() ## Graph State @@ -81,4 +84,4 @@ class DocVerificationInput(ExpandedRetrievalInput): class RetrievalInput(ExpandedRetrievalInput): - query_to_retrieve: str + query_to_retrieve: str = "" diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/edges.py b/backend/onyx/agents/agent_search/deep_search_a/main/edges.py index ed7fca7f0..728fa9a66 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/edges.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/edges.py @@ -22,7 +22,7 @@ logger = setup_logger() def 
parallelize_initial_sub_question_answering( state: MainState, ) -> list[Send | Hashable]: - if len(state["initial_decomp_questions"]) > 0: + if len(state.initial_decomp_questions) > 0: # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] # if len(state["sub_question_records"]) == 0: # if state["config"].use_persistence: @@ -39,9 +39,10 @@ def parallelize_initial_sub_question_answering( AnswerQuestionInput( question=question, question_id=make_question_id(0, question_nr + 1), + log_messages=[], ), ) - for question_nr, question in enumerate(state["initial_decomp_questions"]) + for question_nr, question in enumerate(state.initial_decomp_questions) ] else: @@ -59,7 +60,7 @@ def parallelize_initial_sub_question_answering( def continue_to_refined_answer_or_end( state: RequireRefinedAnswerUpdate, ) -> Literal["refined_sub_question_creation", "logging_node"]: - if state["require_refined_answer"]: + if state.require_refined_answer: return "refined_sub_question_creation" else: return "logging_node" @@ -68,16 +69,17 @@ def continue_to_refined_answer_or_end( def parallelize_refined_sub_question_answering( state: MainState, ) -> list[Send | Hashable]: - if len(state["refined_sub_questions"]) > 0: + if len(state.refined_sub_questions) > 0: return [ Send( "answer_refined_question", AnswerQuestionInput( question=question_data.sub_question, question_id=make_question_id(1, question_nr), + log_messages=[], ), ) - for question_nr, question_data in state["refined_sub_questions"].items() + for question_nr, question_data in state.refined_sub_questions.items() ] else: diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py index d99c8d33d..1edbe75a9 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py @@ -65,6 +65,9 @@ from 
onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision i from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import ( refined_sub_question_creation, ) +from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import ( + retrieval_consolidation, +) from onyx.agents.agent_search.deep_search_a.main.states import MainInput from onyx.agents.agent_search.deep_search_a.main.states import MainState from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config @@ -153,6 +156,12 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: node="ingest_initial_retrieval", action=ingest_initial_base_retrieval, ) + + graph.add_node( + node="retrieval_consolidation", + action=retrieval_consolidation, + ) + graph.add_node( node="ingest_initial_sub_question_answers", action=ingest_initial_sub_question_answers, @@ -215,6 +224,21 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: end_key="ingest_initial_retrieval", ) + graph.add_edge( + start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"], + end_key="retrieval_consolidation", + ) + + graph.add_edge( + start_key="retrieval_consolidation", + end_key="entity_term_extraction_llm", + ) + + graph.add_edge( + start_key="retrieval_consolidation", + end_key="generate_initial_answer", + ) + graph.add_edge( start_key="LLM", end_key=END, @@ -236,14 +260,14 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: ) graph.add_edge( - start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"], + start_key="retrieval_consolidation", end_key="generate_initial_answer", ) - graph.add_edge( - start_key="generate_initial_answer", - end_key="entity_term_extraction_llm", - ) + # graph.add_edge( + # start_key="generate_initial_answer", + # end_key="entity_term_extraction_llm", + # ) graph.add_edge( start_key="generate_initial_answer", @@ -327,7 +351,9 @@ if __name__ == "__main__": db_session, 
primary_llm, fast_llm, search_request ) - inputs = MainInput() + inputs = MainInput( + base_question=agent_a_config.search_request.query, log_messages=[] + ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/models.py b/backend/onyx/agents/agent_search/deep_search_a/main/models.py index 2bb487cb8..e745f5622 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/models.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/models.py @@ -20,16 +20,16 @@ class AgentBaseMetrics(BaseModel): num_verified_documents_core: int | None verified_avg_score_core: float | None num_verified_documents_base: int | float | None - verified_avg_score_base: float | None - base_doc_boost_factor: float | None - support_boost_factor: float | None - duration__s: float | None + verified_avg_score_base: float | None = None + base_doc_boost_factor: float | None = None + support_boost_factor: float | None = None + duration__s: float | None = None class AgentRefinedMetrics(BaseModel): - refined_doc_boost_factor: float | None - refined_question_boost_factor: float | None - duration__s: float | None + refined_doc_boost_factor: float | None = None + refined_question_boost_factor: float | None = None + duration__s: float | None = None class AgentAdditionalMetrics(BaseModel): diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py index d17f4f45b..bc16b6e09 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py @@ -19,10 +19,10 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput: logger.debug(f"--------{now_start}--------LOGGING NODE---") - agent_start_time = state["agent_start_time"] - agent_base_end_time = state["agent_base_end_time"] - agent_refined_start_time = 
state["agent_refined_start_time"] or None - agent_refined_end_time = state["agent_refined_end_time"] or None + agent_start_time = state.agent_start_time + agent_base_end_time = state.agent_base_end_time + agent_refined_start_time = state.agent_refined_start_time or None + agent_refined_end_time = state.agent_refined_end_time or None agent_end_time = agent_refined_end_time or agent_base_end_time agent_base_duration = None @@ -41,8 +41,8 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput: agent_type = "refined" if agent_refined_duration else "base" - agent_base_metrics = state["agent_base_metrics"] - agent_refined_metrics = state["agent_refined_metrics"] + agent_base_metrics = state.agent_base_metrics + agent_refined_metrics = state.agent_refined_metrics combined_agent_metrics = CombinedAgentMetrics( timings=AgentTimings( @@ -81,7 +81,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput: db_session = agent_a_config.db_session chat_session_id = agent_a_config.chat_session_id primary_message_id = agent_a_config.message_id - sub_question_answer_results = state["decomp_answer_results"] + sub_question_answer_results = state.decomp_answer_results log_agent_sub_question_results( db_session=db_session, diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py index 3dcd3d1a3..639464e83 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py @@ -17,6 +17,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ( ) from onyx.context.search.models import InferenceSection from onyx.db.engine import get_session_context_manager +from onyx.llm.utils import check_number_of_tokens from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) @@ 
-86,9 +87,13 @@ def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDeci f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---" ) + check_number_of_tokens(agent_decision_prompt) + return RoutingDecision( # Decide which route to take routing=routing, sample_doc_str=sample_doc_str, - log_messages=[f"Path decision: {routing}, Time taken: {now_end - now_start}"], + log_messages=[ + f"{now_start} -- Path decision: {routing}, Time taken: {now_end - now_start}" + ], ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py index 74c9d78f9..20af5d982 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py @@ -8,7 +8,7 @@ from onyx.agents.agent_search.deep_search_a.main.states import MainState def agent_path_routing( state: MainState, ) -> Command[Literal["agent_search_start", "LLM"]]: - routing = state.get("routing", "agent_search") + routing = state.routing if state.routing else "agent_search" if routing == "agent_search": agent_path = "agent_search_start" diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py index 381b9e48f..983a918cd 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py @@ -48,14 +48,13 @@ def entity_term_extraction_llm( # first four lines duplicates from generate_initial_answer question = agent_a_config.search_request.query - sub_question_docs =
state.documents + all_original_question_documents = state.all_original_question_documents relevant_docs = dedup_inference_sections( sub_question_docs, all_original_question_documents ) # start with the entity/term/extraction - doc_context = format_docs(relevant_docs) doc_context = trim_prompt_piece( @@ -127,5 +126,8 @@ def entity_term_extraction_llm( entities=entities, relationships=relationships, terms=terms, - ) + ), + log_messages=[ + f"{now_start} -- Entity Term Extraction - Time taken: {now_end - now_start}" + ], ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py index 5dd2c44c1..5020f456f 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py @@ -64,8 +64,8 @@ def generate_initial_answer( history = build_history_prompt(agent_a_config.message_history) - sub_question_docs = state["documents"] - all_original_question_documents = state["all_original_question_documents"] + sub_question_docs = state.documents + all_original_question_documents = state.all_original_question_documents relevant_docs = dedup_inference_sections( sub_question_docs, all_original_question_documents @@ -92,7 +92,7 @@ def generate_initial_answer( else: # Use the query info from the base document retrieval - query_info = get_query_info(state["original_question_retrieval_results"]) + query_info = get_query_info(state.original_question_retrieval_results) for tool_response in yield_search_responses( query=question, @@ -117,7 +117,7 @@ def generate_initial_answer( if all_original_question_doc not in sub_question_docs: net_new_original_question_docs.append(all_original_question_doc) - decomp_answer_results = state["decomp_answer_results"] + decomp_answer_results = state.decomp_answer_results good_qa_list: list[str] = [] @@ -206,7 
+206,7 @@ def generate_initial_answer( answer = cast(str, response) initial_agent_stats = calculate_initial_agent_stats( - state["decomp_answer_results"], state["original_question_retrieval_stats"] + state.decomp_answer_results, state.original_question_retrieval_stats ) logger.debug( @@ -228,12 +228,8 @@ def generate_initial_answer( agent_base_metrics = AgentBaseMetrics( num_verified_documents_total=len(relevant_docs), - num_verified_documents_core=state[ - "original_question_retrieval_stats" - ].verified_count, - verified_avg_score_core=state[ - "original_question_retrieval_stats" - ].verified_avg_scores, + num_verified_documents_core=state.original_question_retrieval_stats.verified_count, + verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores, num_verified_documents_base=initial_agent_stats.sub_questions.get( "num_verified_documents", None ), @@ -246,7 +242,7 @@ def generate_initial_answer( support_boost_factor=initial_agent_stats.agent_effectiveness.get( "support_ratio", None ), - duration__s=(agent_base_end_time - state["agent_start_time"]).total_seconds(), + duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(), ) return InitialAnswerUpdate( @@ -255,5 +251,7 @@ def generate_initial_answer( generated_sub_questions=decomp_questions, agent_base_end_time=agent_base_end_time, agent_base_metrics=agent_base_metrics, - log_messages=[f"Initial answer generation: {now_end - now_start}"], + log_messages=[ + f"{now_start} -- Initial answer generation - Time taken: {now_end - now_start}" + ], ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py index c1afbef2f..ab3e00855 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py +++ 
b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py @@ -25,7 +25,7 @@ def generate_initial_base_search_only_answer( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) question = agent_a_config.search_request.query - original_question_docs = state["all_original_question_documents"] + original_question_docs = state.all_original_question_documents model = agent_a_config.fast_llm diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py index b966ffe96..b7dc5cd73 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py @@ -61,12 +61,12 @@ def generate_refined_answer( history = build_history_prompt(agent_a_config.message_history) - initial_documents = state["documents"] - revised_documents = state["refined_documents"] + initial_documents = state.documents + revised_documents = state.refined_documents combined_documents = dedup_inference_sections(initial_documents, revised_documents) - query_info = get_query_info(state["original_question_retrieval_results"]) + query_info = get_query_info(state.original_question_retrieval_results) # stream refined answer docs for tool_response in yield_search_responses( query=question, @@ -93,8 +93,8 @@ def generate_refined_answer( else: revision_doc_effectiveness = 10.0 - decomp_answer_results = state["decomp_answer_results"] - # revised_answer_results = state["refined_decomp_answer_results"] + decomp_answer_results = state.decomp_answer_results + # revised_answer_results = state.refined_decomp_answer_results good_qa_list: list[str] = [] decomp_questions = [] @@ -147,7 +147,7 @@ def generate_refined_answer( # original answer - initial_answer = state["initial_answer"] + initial_answer = state.initial_answer # 
Determine which persona-specification prompt to use @@ -218,7 +218,7 @@ def generate_refined_answer( answer = cast(str, response) # refined_agent_stats = _calculate_refined_agent_stats( - # state["decomp_answer_results"], state["original_question_retrieval_stats"] + # state.decomp_answer_results, state.original_question_retrieval_stats # ) initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions))) @@ -252,22 +252,22 @@ def generate_refined_answer( logger.debug("-" * 100) - if state["initial_agent_stats"]: - initial_doc_boost_factor = state["initial_agent_stats"].agent_effectiveness.get( + if state.initial_agent_stats: + initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get( "utilized_chunk_ratio", "--" ) - initial_support_boost_factor = state[ - "initial_agent_stats" - ].agent_effectiveness.get("support_ratio", "--") - num_initial_verified_docs = state["initial_agent_stats"].original_question.get( + initial_support_boost_factor = ( + state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--") + ) + num_initial_verified_docs = state.initial_agent_stats.original_question.get( "num_verified_documents", "--" ) - initial_verified_docs_avg_score = state[ - "initial_agent_stats" - ].original_question.get("verified_avg_score", "--") - initial_sub_questions_verified_docs = state[ - "initial_agent_stats" - ].sub_questions.get("num_verified_documents", "--") + initial_verified_docs_avg_score = ( + state.initial_agent_stats.original_question.get("verified_avg_score", "--") + ) + initial_sub_questions_verified_docs = ( + state.initial_agent_stats.sub_questions.get("num_verified_documents", "--") + ) logger.debug("INITIAL AGENT STATS") logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}") @@ -296,9 +296,9 @@ def generate_refined_answer( ) agent_refined_end_time = datetime.now() - if state["agent_refined_start_time"]: + if state.agent_refined_start_time: agent_refined_duration = ( - agent_refined_end_time - 
state["agent_refined_start_time"] + agent_refined_end_time - state.agent_refined_start_time ).total_seconds() else: agent_refined_duration = None diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py index 0ee387376..033ecae47 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py @@ -15,9 +15,9 @@ def ingest_initial_base_retrieval( logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---") - sub_question_retrieval_stats = state[ - "base_expanded_retrieval_result" - ].sub_question_retrieval_stats + sub_question_retrieval_stats = ( + state.base_expanded_retrieval_result.sub_question_retrieval_stats + ) if sub_question_retrieval_stats is None: sub_question_retrieval_stats = AgentChunkStats() else: @@ -30,11 +30,7 @@ def ingest_initial_base_retrieval( ) return ExpandedRetrievalUpdate( - original_question_retrieval_results=state[ - "base_expanded_retrieval_result" - ].expanded_queries_results, - all_original_question_documents=state[ - "base_expanded_retrieval_result" - ].all_documents, + original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results, + all_original_question_documents=state.base_expanded_retrieval_result.all_documents, original_question_retrieval_stats=sub_question_retrieval_stats, ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py index 90a21c206..b510cfcb4 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py +++ 
b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py @@ -17,7 +17,7 @@ def ingest_initial_sub_question_answers( logger.debug(f"--------{now_start}--------INGEST ANSWERS---") documents = [] - answer_results = state.get("answer_results", []) + answer_results = state.answer_results if hasattr(state, "answer_results") else [] for answer_result in answer_results: documents.extend(answer_result.documents) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py index 2a3384f6d..080fbd65a 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py @@ -18,7 +18,7 @@ def ingest_refined_answers( logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---") documents = [] - answer_results = state.get("answer_results", []) + answer_results = state.answer_results if hasattr(state, "answer_results") else [] for answer_result in answer_results: documents.extend(answer_result.documents) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py index 271571bb8..10069425c 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py @@ -53,7 +53,7 @@ def initial_sub_question_creation( history = build_history_prompt(agent_a_config.message_history) # Use the initial search results to inform the decomposition - sample_doc_str = state.get("sample_doc_str", "") + sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else "" if not chat_session_id or not primary_message_id: raise ValueError( 
diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py index 8cad703f3..9efb524e4 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py @@ -32,8 +32,8 @@ def refined_answer_decision( f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---" ) - if not agent_a_config.allow_refinement: + if agent_a_config.allow_refinement: return RequireRefinedAnswerUpdate(require_refined_answer=decision) else: - return RequireRefinedAnswerUpdate(require_refined_answer=not decision) + return RequireRefinedAnswerUpdate(require_refined_answer=False) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py index 942ab1224..5b3cde984 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py @@ -37,7 +37,7 @@ def refined_sub_question_creation( tool_name="agent_search_1", tool_args={ "query": agent_a_config.search_request.query, - "answer": state["initial_answer"], + "answer": state.initial_answer, }, ), ) @@ -49,16 +49,16 @@ def refined_sub_question_creation( agent_refined_start_time = datetime.now() question = agent_a_config.search_request.query - base_answer = state["initial_answer"] + base_answer = state.initial_answer history = build_history_prompt(agent_a_config.message_history) # get the entity term extraction dict and properly format it - entity_retlation_term_extractions = state["entity_retlation_term_extractions"] + entity_retlation_term_extractions = state.entity_retlation_term_extractions entity_term_extraction_str = 
format_entity_term_extraction( entity_retlation_term_extractions ) - initial_question_answers = state["decomp_answer_results"] + initial_question_answers = state.decomp_answer_results addressed_question_list = [ x.question for x in initial_question_answers if "yes" in x.quality.lower() diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/retrieval_consolidation.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/retrieval_consolidation.py new file mode 100644 index 000000000..a1cd8ee54 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/retrieval_consolidation.py @@ -0,0 +1,12 @@ +from datetime import datetime + +from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate +from onyx.agents.agent_search.deep_search_a.main.states import MainState + + +def retrieval_consolidation( + state: MainState, +) -> LoggerUpdate: + now_start = datetime.now() + + return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"]) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/states.py b/backend/onyx/agents/agent_search/deep_search_a/main/states.py index 9081f4a88..d7398edd8 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/states.py @@ -3,6 +3,8 @@ from operator import add from typing import Annotated from typing import TypedDict +from pydantic import BaseModel + from onyx.agents.agent_search.core_state import CoreState from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( ExpandedRetrievalResult, @@ -28,33 +30,36 @@ from onyx.agents.agent_search.shared_graph_utils.operators import ( ) from onyx.context.search.models import InferenceSection - ### States ### ## Update States -class RefinedAgentStartStats(TypedDict): - agent_refined_start_time: datetime | None +class LoggerUpdate(BaseModel): + log_messages: Annotated[list[str], add] = [] -class RefinedAgentEndStats(TypedDict): 
- agent_refined_end_time: datetime | None - agent_refined_metrics: AgentRefinedMetrics +class RefinedAgentStartStats(BaseModel): + agent_refined_start_time: datetime | None = None -class BaseDecompUpdateBase(TypedDict): - agent_start_time: datetime - initial_decomp_questions: list[str] +class RefinedAgentEndStats(BaseModel): + agent_refined_end_time: datetime | None = None + agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics() -class RoutingDecisionBase(TypedDict): - routing: str - sample_doc_str: str +class BaseDecompUpdateBase(BaseModel): + agent_start_time: datetime = datetime.now() + initial_decomp_questions: list[str] = [] -class RoutingDecision(RoutingDecisionBase): - log_messages: list[str] +class RoutingDecisionBase(BaseModel): + routing: str = "" + sample_doc_str: str = "" + + +class RoutingDecision(RoutingDecisionBase, LoggerUpdate): + pass class BaseDecompUpdate( @@ -63,66 +68,72 @@ class BaseDecompUpdate( pass -class InitialAnswerBASEUpdate(TypedDict): - initial_base_answer: str +class InitialAnswerBASEUpdate(BaseModel): + initial_base_answer: str = "" -class InitialAnswerUpdateBase(TypedDict): - initial_answer: str - initial_agent_stats: InitialAgentResultStats | None - generated_sub_questions: list[str] - agent_base_end_time: datetime - agent_base_metrics: AgentBaseMetrics | None +class InitialAnswerUpdateBase(BaseModel): + initial_answer: str = "" + initial_agent_stats: InitialAgentResultStats | None = None + generated_sub_questions: list[str] = [] + agent_base_end_time: datetime | None = None + agent_base_metrics: AgentBaseMetrics | None = None -class InitialAnswerUpdate(InitialAnswerUpdateBase): - log_messages: list[str] +class InitialAnswerUpdate(InitialAnswerUpdateBase, LoggerUpdate): + pass -class RefinedAnswerUpdateBase(TypedDict): - refined_answer: str - refined_agent_stats: RefinedAgentStats | None - refined_answer_quality: bool +class RefinedAnswerUpdateBase(BaseModel): + refined_answer: str = "" + refined_agent_stats: 
RefinedAgentStats | None = None + refined_answer_quality: bool = False class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase): pass -class InitialAnswerQualityUpdate(TypedDict): - initial_answer_quality: bool +class InitialAnswerQualityUpdate(BaseModel): + initial_answer_quality: bool = False -class RequireRefinedAnswerUpdate(TypedDict): - require_refined_answer: bool +class RequireRefinedAnswerUpdate(BaseModel): + require_refined_answer: bool = True -class DecompAnswersUpdate(TypedDict): - documents: Annotated[list[InferenceSection], dedup_inference_sections] +class DecompAnswersUpdate(BaseModel): + documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] decomp_answer_results: Annotated[ list[QuestionAnswerResults], dedup_question_answer_results - ] + ] = [] -class FollowUpDecompAnswersUpdate(TypedDict): - refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] - refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] +class FollowUpDecompAnswersUpdate(BaseModel): + refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = [] + refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = [] -class ExpandedRetrievalUpdate(TypedDict): +class ExpandedRetrievalUpdate(BaseModel): all_original_question_documents: Annotated[ list[InferenceSection], dedup_inference_sections ] - original_question_retrieval_results: list[QueryResult] - original_question_retrieval_stats: AgentChunkStats + original_question_retrieval_results: list[QueryResult] = [] + original_question_retrieval_stats: AgentChunkStats = AgentChunkStats() -class EntityTermExtractionUpdate(TypedDict): - entity_retlation_term_extractions: EntityRelationshipTermExtraction +class EntityTermExtractionUpdateBase(BaseModel): + entity_retlation_term_extractions: EntityRelationshipTermExtraction = ( + EntityRelationshipTermExtraction() + ) -class FollowUpSubQuestionsUpdateBase(TypedDict): - 
refined_sub_questions: dict[int, FollowUpSubQuestion] +class EntityTermExtractionUpdate(EntityTermExtractionUpdateBase, LoggerUpdate): + pass + + +class FollowUpSubQuestionsUpdateBase(BaseModel): + refined_sub_questions: dict[int, FollowUpSubQuestion] = {} class FollowUpSubQuestionsUpdate( @@ -145,12 +156,13 @@ class MainInput(CoreState): class MainState( # This includes the core state MainInput, + LoggerUpdate, BaseDecompUpdateBase, InitialAnswerUpdateBase, InitialAnswerBASEUpdate, DecompAnswersUpdate, ExpandedRetrievalUpdate, - EntityTermExtractionUpdate, + EntityTermExtractionUpdateBase, InitialAnswerQualityUpdate, RequireRefinedAnswerUpdate, FollowUpSubQuestionsUpdateBase, diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index c3857dbb6..97e68fb26 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -22,7 +22,7 @@ class AgentSearchConfig: primary_llm: LLM fast_llm: LLM search_tool: SearchTool - use_agentic_search: bool = False + use_agentic_search: bool = True # For persisting agent search data chat_session_id: UUID | None = None @@ -37,13 +37,13 @@ class AgentSearchConfig: db_session: Session | None = None # Whether to perform initial search to inform decomposition - perform_initial_search_path_decision: bool = False + perform_initial_search_path_decision: bool = True # Whether to perform initial search to inform decomposition - perform_initial_search_decomposition: bool = False + perform_initial_search_decomposition: bool = True # Whether to allow creation of refinement questions (and entity extraction, etc.) 
- allow_refinement: bool = False + allow_refinement: bool = True # Message history for the current chat session message_history: list[PreviousMessage] | None = None diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py index 4dc238d18..5761f112c 100644 --- a/backend/onyx/agents/agent_search/run_graph.py +++ b/backend/onyx/agents/agent_search/run_graph.py @@ -147,6 +147,7 @@ def run_graph( # TODO: add these to the environment config.perform_initial_search_path_decision = True config.perform_initial_search_decomposition = True + config.allow_refinement = True for event in _manage_async_event_streaming( compiled_graph=compiled_graph, config=config, graph_input=input @@ -176,9 +177,9 @@ def run_main_graph( ) -> AnswerStream: compiled_graph = load_compiled_graph(graph_name) if graph_name == "a": - input = MainInput_a() + input = MainInput_a(base_question=config.search_request.query, log_messages=[]) else: - input = MainInput_a() + input = MainInput_a(base_question=config.search_request.query, log_messages=[]) # Agent search is not a Tool per se, but this is helpful for the frontend yield ToolCallKickoff( @@ -238,9 +239,13 @@ if __name__ == "__main__": config.perform_initial_search_path_decision = True config.perform_initial_search_decomposition = True if GRAPH_NAME == "a": - input = MainInput_a() + input = MainInput_a( + base_question=config.search_request.query, log_messages=[] + ) else: - input = MainInput_a() + input = MainInput_a( + base_question=config.search_request.query, log_messages=[] + ) # with open("output.txt", "w") as f: tool_responses: list = [] for output in run_graph(compiled_graph, config, input): diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/models.py b/backend/onyx/agents/agent_search/shared_graph_utils/models.py index c38c3db82..11005a1c8 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/models.py 
@@ -40,12 +40,12 @@ class AgentChunkScores(BaseModel): class AgentChunkStats(BaseModel): - verified_count: int | None - verified_avg_scores: float | None - rejected_count: int | None - rejected_avg_scores: float | None - verified_doc_chunk_ids: list[str] - dismissed_doc_chunk_ids: list[str] + verified_count: int | None = None + verified_avg_scores: float | None = None + rejected_count: int | None = None + rejected_avg_scores: float | None = None + verified_doc_chunk_ids: list[str] = [] + dismissed_doc_chunk_ids: list[str] = [] class InitialAgentResultStats(BaseModel): @@ -60,29 +60,29 @@ class RefinedAgentStats(BaseModel): class Term(BaseModel): - term_name: str - term_type: str - term_similar_to: list[str] + term_name: str = "" + term_type: str = "" + term_similar_to: list[str] = [] ### Models ### class Entity(BaseModel): - entity_name: str - entity_type: str + entity_name: str = "" + entity_type: str = "" class Relationship(BaseModel): - relationship_name: str - relationship_type: str - relationship_entities: list[str] + relationship_name: str = "" + relationship_type: str = "" + relationship_entities: list[str] = [] class EntityRelationshipTermExtraction(BaseModel): - entities: list[Entity] - relationships: list[Relationship] - terms: list[Term] + entities: list[Entity] = [] + relationships: list[Relationship] = [] + terms: list[Term] = [] ### Models ###