pydantic models for LangGraph state + changed ERT (entity/relationship/term) extraction flow

joachim-danswer 2025-01-22 09:39:07 -08:00 committed by Evan Lohn
parent b9bd2ea4e2
commit b7f9e431a5
45 changed files with 364 additions and 254 deletions
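The pattern applied throughout this commit: LangGraph state classes move from TypedDict to pydantic BaseModel, bracket access (state["question"]) becomes attribute access (state.question), and every field gains a default because the pydantic schema is instantiated and validated up front. A minimal sketch of the pattern, with illustrative names (not part of this diff):

from operator import add
from typing import Annotated

from pydantic import BaseModel


class MyState(BaseModel):
    # Before: class MyState(TypedDict, total=False) with state["question"].
    question: str = ""
    # `add` is a reducer: list updates from parallel branches are concatenated.
    log_messages: Annotated[list[str], add] = []


def my_node(state: MyState) -> dict:
    # Nodes can still return partial updates; LangGraph merges them into state.
    return {"log_messages": [f"handled: {state.question}"]}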

View File

@ -1,18 +1,19 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
class CoreState(TypedDict, total=False):
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
base_question: str
log_messages: Annotated[list[str], add]
base_question: str = ""
log_messages: Annotated[list[str], add] = []
class SubgraphCoreState(TypedDict, total=False):
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""

View File

@ -1,4 +1,5 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
@ -15,12 +16,14 @@ logger = setup_logger()
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval via edge")
now_start = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state["question"],
question=state.question,
base_search=False,
sub_question_id=state["question_id"],
sub_question_id=state.question_id,
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
),
)

View File

@ -115,6 +115,7 @@ if __name__ == "__main__":
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,

View File

@ -17,15 +17,15 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
if state["answer"] == UNKNOWN_ANSWER:
if state.answer == UNKNOWN_ANSWER:
return QACheckUpdate(
answer_quality=SUB_CHECK_NO,
)
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
question=state["question"],
base_answer=state["answer"],
question=state.question,
base_answer=state.answer,
)
)
]

View File

@ -40,9 +40,9 @@ def answer_generation(
logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state["question"]
docs = state["documents"]
level, question_nr = parse_question_id(state["question_id"])
question = state.question
docs = state.documents
level, question_nr = parse_question_id(state.question_id)
persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
if len(docs) == 0:

View File

@ -13,13 +13,15 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
return AnswerQuestionOutput(
answer_results=[
QuestionAnswerResults(
question=state["question"],
question_id=state["question_id"],
quality=state.get("answer_quality", "No"),
answer=state["answer"],
expanded_retrieval_results=state["expanded_retrieval_results"],
documents=state["documents"],
sub_question_retrieval_stats=state["sub_question_retrieval_stats"],
question=state.question,
question_id=state.question_id,
quality=state.answer_quality if state.answer_quality else "No",
answer=state.answer,
expanded_retrieval_results=state.expanded_retrieval_results,
documents=state.documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@ -8,16 +8,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
sub_question_retrieval_stats = state[
"expanded_retrieval_result"
].sub_question_retrieval_stats
sub_question_retrieval_stats = (
state.expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkStats()]
return RetrievalIngestionUpdate(
expanded_retrieval_results=state[
"expanded_retrieval_result"
].expanded_queries_results,
documents=state["expanded_retrieval_result"].all_documents,
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
documents=state.expanded_retrieval_result.all_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@ -1,6 +1,7 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
@ -15,27 +16,29 @@ from onyx.context.search.models import InferenceSection
## Update States
class QACheckUpdate(TypedDict):
answer_quality: str
class QACheckUpdate(BaseModel):
answer_quality: str = ""
class QAGenerationUpdate(TypedDict):
answer: str
class QAGenerationUpdate(BaseModel):
answer: str = ""
# answer_stat: AnswerStats
class RetrievalIngestionUpdate(TypedDict):
expanded_retrieval_results: list[QueryResult]
documents: Annotated[list[InferenceSection], dedup_inference_sections]
sub_question_retrieval_stats: AgentChunkStats
class RetrievalIngestionUpdate(BaseModel):
expanded_retrieval_results: list[QueryResult] = []
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
## Graph Input State
class AnswerQuestionInput(SubgraphCoreState):
question: str
question_id: str # 0_0 is original question, everything else is <level>_<question_num>.
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
@ -55,11 +58,11 @@ class AnswerQuestionState(
## Graph Output State
class AnswerQuestionOutput(TypedDict):
class AnswerQuestionOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph returns only one result.
When the answer-query subgraph is parallelized there are multiple results, and the
add reducer concatenates them into a single list.
"""
answer_results: Annotated[list[QuestionAnswerResults], add]
answer_results: Annotated[list[QuestionAnswerResults], add] = []
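A small illustration of the docstring above, assuming the standard behavior of operator.add as a LangGraph reducer (example names are not from the diff):

from operator import add

# Each parallel subgraph call emits a one-element list; the `add` reducer
# folds the updates together, so the parent state sees all results.
merged: list[str] = []
for branch_update in (["result_a"], ["result_b"]):
    merged = add(merged, branch_update)  # equivalent to merged + branch_update
assert merged == ["result_a", "result_b"]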

View File

@ -1,4 +1,5 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
@ -15,12 +16,13 @@ logger = setup_logger()
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval for follow up question via edge")
datetime.now()
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state["question"],
sub_question_id=state["question_id"],
question=state.question,
sub_question_id=state.question_id,
base_search=False,
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
),
)

View File

@ -111,6 +111,7 @@ if __name__ == "__main__":
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,

View File

@ -12,7 +12,7 @@ logger = setup_logger()
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
logger.debug("format_raw_search_results")
return BaseRawSearchOutput(
base_expanded_retrieval_result=state["expanded_retrieval_result"],
# base_retrieval_results=[state["expanded_retrieval_result"]],
base_expanded_retrieval_result=state.expanded_retrieval_result,
# base_retrieval_results=[state.expanded_retrieval_result],
# base_search_documents=[],
)

View File

@ -21,4 +21,5 @@ def generate_raw_search_data(
question=agent_a_config.search_request.query,
base_search=True,
sub_question_id=None, # This graph is only ever used for the original question
log_messages=[],
)

View File

@ -1,4 +1,4 @@
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
@ -21,7 +21,7 @@ class BaseRawSearchInput(ExpandedRetrievalInput):
## Graph Output State
class BaseRawSearchOutput(TypedDict):
class BaseRawSearchOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
@ -30,7 +30,7 @@ class BaseRawSearchOutput(TypedDict):
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
base_expanded_retrieval_result: ExpandedRetrievalResult
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
## Graph State

View File

@ -17,9 +17,11 @@ def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.get("question", agent_a_config.search_request.query)
question = state.question if state.question else agent_a_config.search_request.query
query_expansions = state.get("expanded_queries", []) + [question]
query_expansions = (
state.expanded_queries if state.expanded_queries else []
) + [question] # parentheses so [question] is appended in both branches
return [
Send(
"doc_retrieval",
@ -27,7 +29,8 @@ def parallel_retrieval_edge(
query_to_retrieve=query,
question=question,
base_search=False,
sub_question_id=state.get("sub_question_id"),
sub_question_id=state.sub_question_id,
log_messages=[],
),
)
for query in query_expansions
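For reference, a minimal sketch of the Send fan-out used by this edge, assuming LangGraph's documented map-reduce pattern (the dict payload is illustrative; the real code passes a RetrievalInput model):

from langgraph.types import Send

def fan_out(queries: list[str]) -> list[Send]:
    # Returning a list of Send objects from a conditional edge launches one
    # "doc_retrieval" run per query, each with its own input state.
    return [Send("doc_retrieval", {"query_to_retrieve": q}) for q in queries]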

View File

@ -14,6 +14,9 @@ from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retriev
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
doc_verification,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
dummy,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
@ -52,6 +55,11 @@ def expanded_retrieval_graph_builder() -> StateGraph:
action=expand_queries,
)
graph.add_node(
node="dummy",
action=dummy,
)
graph.add_node(
node="doc_retrieval",
action=doc_retrieval,
@ -78,9 +86,13 @@ def expanded_retrieval_graph_builder() -> StateGraph:
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="dummy",
)
graph.add_conditional_edges(
source="expand_queries",
source="dummy",
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
@ -124,6 +136,7 @@ if __name__ == "__main__":
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,

View File

@ -6,6 +6,6 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalResult(BaseModel):
expanded_queries_results: list[QueryResult]
all_documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats
expanded_queries_results: list[QueryResult] = []
all_documents: list[InferenceSection] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()

View File

@ -24,13 +24,13 @@ from onyx.db.engine import get_session_context_manager
def doc_reranking(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
verified_documents = state["verified_documents"]
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.get("question", agent_a_config.search_request.query)
question = state.question if state.question else agent_a_config.search_request.query
with get_session_context_manager() as db_session:
_search_query = retrieval_preprocessing(
search_request=SearchRequest(query=question),

View File

@ -35,7 +35,7 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval
expanded_retrieval_results: list[ExpandedRetrievalResult]
retrieved_documents: list[InferenceSection]
"""
query_to_retrieve = state["query_to_retrieve"]
query_to_retrieve = state.query_to_retrieve
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
search_tool = agent_a_config.search_tool

View File

@ -30,8 +30,8 @@ def doc_verification(
verified_documents: list[InferenceSection]
"""
question = state["question"]
doc_to_verify = state["doc_to_verify"]
question = state.question
doc_to_verify = state.doc_to_verify
document_content = doc_to_verify.combined_content
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])

View File

@ -0,0 +1,16 @@
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
QueryExpansionUpdate,
)
def dummy(
state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
return QueryExpansionUpdate(
expanded_queries=state.expanded_queries,
)
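Note on this new node: dummy is a pure pass-through that re-emits expanded_queries unchanged. Judging from the wiring in the graph builder above, its purpose is structural: routing the conditional fan-out edge out of dummy instead of directly out of expand_queries gives the graph a point where the query-expansion update has already been merged into state before parallel_retrieval_edge reads it. (This rationale is inferred from the wiring; the commit does not state it.)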

View File

@ -28,10 +28,14 @@ def expand_queries(
# When we are running this node on the original question, no question is explicitly passed in.
# Instead, we use the original question from the search request.
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.get("question", agent_a_config.search_request.query)
question = state.question if state.question else agent_a_config.search_request.query
llm = agent_a_config.fast_llm
chat_session_id = agent_a_config.chat_session_id
sub_question_id = state.get("sub_question_id")
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_nr = 0, 0
else:

View File

@ -25,10 +25,10 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp
def format_results(
state: ExpandedRetrievalState, config: RunnableConfig
) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0")
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_infos = [
result.query_info
for result in state["expanded_retrieval_results"]
for result in state.expanded_retrieval_results
if result.query_info is not None
]
if len(query_infos) == 0:
@ -37,19 +37,15 @@ def format_results(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs
if not (level == 0 and question_nr == 0):
if len(state["reranked_documents"]) > 0:
stream_documents = state["reranked_documents"]
if len(state.reranked_documents) > 0:
stream_documents = state.reranked_documents
else:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
stream_documents = state["expanded_retrieval_results"][-1].search_results[
:3
]
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
for tool_response in yield_search_responses(
query=state["question"],
reranked_sections=state[
"retrieved_documents"
], # TODO: rename params. this one is supposed to be the sections pre-merging
query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
final_context_sections=stream_documents,
search_query_info=query_infos[0], # TODO: handle differing query infos?
get_section_relevance=lambda: None, # TODO: add relevance
@ -65,8 +61,8 @@ def format_results(
),
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state["verified_documents"],
expanded_retrieval_results=state["expanded_retrieval_results"],
verified_documents=state.verified_documents,
expanded_retrieval_results=state.expanded_retrieval_results,
)
if sub_question_retrieval_stats is None:
@ -76,8 +72,8 @@ def format_results(
return ExpandedRetrievalUpdate(
expanded_retrieval_result=ExpandedRetrievalResult(
expanded_queries_results=state["expanded_retrieval_results"],
all_documents=state["reranked_documents"],
expanded_queries_results=state.expanded_retrieval_results,
all_documents=state.reranked_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@ -18,10 +18,14 @@ def verification_kickoff(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["doc_verification"]]:
documents = state["retrieved_documents"]
documents = state.retrieved_documents
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
verification_question = state.get("question", agent_a_config.search_request.query)
sub_question_id = state.get("sub_question_id")
verification_question = (
state.question if state.question else agent_a_config.search_request.query
)
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
@ -32,6 +36,7 @@ def verification_kickoff(
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for doc in documents

View File

@ -1,6 +1,7 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
@ -20,42 +21,44 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalInput(SubgraphCoreState):
question: str
base_search: bool
sub_question_id: str | None
question: str = ""
base_search: bool = False
sub_question_id: str | None = None
## Update/Return States
class QueryExpansionUpdate(TypedDict):
expanded_queries: list[str]
class QueryExpansionUpdate(BaseModel):
expanded_queries: list[str] = []
class DocVerificationUpdate(TypedDict):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DocVerificationUpdate(BaseModel):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
class DocRetrievalUpdate(TypedDict):
expanded_retrieval_results: Annotated[list[QueryResult], add]
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DocRetrievalUpdate(BaseModel):
expanded_retrieval_results: Annotated[list[QueryResult], add] = []
retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
class DocRerankingUpdate(TypedDict):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
sub_question_retrieval_stats: RetrievalFitStats | None
class DocRerankingUpdate(BaseModel):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: RetrievalFitStats | None = None
class ExpandedRetrievalUpdate(TypedDict):
class ExpandedRetrievalUpdate(BaseModel):
expanded_retrieval_result: ExpandedRetrievalResult
## Graph Output State
class ExpandedRetrievalOutput(TypedDict):
expanded_retrieval_result: ExpandedRetrievalResult
base_expanded_retrieval_result: ExpandedRetrievalResult
class ExpandedRetrievalOutput(BaseModel):
expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
## Graph State
@ -81,4 +84,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str
query_to_retrieve: str = ""
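The Annotated[..., dedup_inference_sections] fields above rely on a custom reducer. A sketch of the shape such a reducer takes, assuming a merge keyed on document id (the real implementation lives in onyx.agents.agent_search.shared_graph_utils.operators and may differ):

def dedup_by_id(current: list, new: list) -> list:
    # A LangGraph reducer receives the current value and the incoming
    # update and returns the merged result; here, concatenation with
    # duplicates dropped.
    seen = {getattr(doc, "document_id", id(doc)) for doc in current}
    merged = list(current)
    for doc in new:
        key = getattr(doc, "document_id", id(doc))
        if key not in seen:
            seen.add(key)
            merged.append(doc)
    return merged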

View File

@ -22,7 +22,7 @@ logger = setup_logger()
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
if len(state["initial_decomp_questions"]) > 0:
if len(state.initial_decomp_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
@ -39,9 +39,10 @@ def parallelize_initial_sub_question_answering(
AnswerQuestionInput(
question=question,
question_id=make_question_id(0, question_nr + 1),
log_messages=[],
),
)
for question_nr, question in enumerate(state["initial_decomp_questions"])
for question_nr, question in enumerate(state.initial_decomp_questions)
]
else:
@ -59,7 +60,7 @@ def parallelize_initial_sub_question_answering(
def continue_to_refined_answer_or_end(
state: RequireRefinedAnswerUpdate,
) -> Literal["refined_sub_question_creation", "logging_node"]:
if state["require_refined_answer"]:
if state.require_refined_answer:
return "refined_sub_question_creation"
else:
return "logging_node"
@ -68,16 +69,17 @@ def continue_to_refined_answer_or_end(
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
if len(state["refined_sub_questions"]) > 0:
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question",
AnswerQuestionInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_nr),
log_messages=[],
),
)
for question_nr, question_data in state["refined_sub_questions"].items()
for question_nr, question_data in state.refined_sub_questions.items()
]
else:

View File

@ -65,6 +65,9 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision i
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
refined_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
retrieval_consolidation,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@ -153,6 +156,12 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
node="ingest_initial_retrieval",
action=ingest_initial_base_retrieval,
)
graph.add_node(
node="retrieval_consolidation",
action=retrieval_consolidation,
)
graph.add_node(
node="ingest_initial_sub_question_answers",
action=ingest_initial_sub_question_answers,
@ -215,6 +224,21 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
end_key="ingest_initial_retrieval",
)
graph.add_edge(
start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
end_key="retrieval_consolidation",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="entity_term_extraction_llm",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="LLM",
end_key=END,
@ -236,14 +260,14 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
)
graph.add_edge(
start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"],
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="entity_term_extraction_llm",
)
# graph.add_edge(
# start_key="generate_initial_answer",
# end_key="entity_term_extraction_llm",
# )
graph.add_edge(
start_key="generate_initial_answer",
@ -327,7 +351,9 @@ if __name__ == "__main__":
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput()
inputs = MainInput(
base_question=agent_a_config.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,

View File

@ -20,16 +20,16 @@ class AgentBaseMetrics(BaseModel):
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None
base_doc_boost_factor: float | None
support_boost_factor: float | None
duration__s: float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration__s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None
refined_question_boost_factor: float | None
duration__s: float | None
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration__s: float | None = None
class AgentAdditionalMetrics(BaseModel):

View File

@ -19,10 +19,10 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
logger.debug(f"--------{now_start}--------LOGGING NODE---")
agent_start_time = state["agent_start_time"]
agent_base_end_time = state["agent_base_end_time"]
agent_refined_start_time = state["agent_refined_start_time"] or None
agent_refined_end_time = state["agent_refined_end_time"] or None
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time or None
agent_refined_end_time = state.agent_refined_end_time or None
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
@ -41,8 +41,8 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state["agent_base_metrics"]
agent_refined_metrics = state["agent_refined_metrics"]
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
@ -81,7 +81,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
db_session = agent_a_config.db_session
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
sub_question_answer_results = state["decomp_answer_results"]
sub_question_answer_results = state.decomp_answer_results
log_agent_sub_question_results(
db_session=db_session,

View File

@ -17,6 +17,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import (
)
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.llm.utils import check_number_of_tokens
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@ -86,9 +87,13 @@ def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDeci
f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
)
check_number_of_tokens(agent_decision_prompt)
return RoutingDecision(
# Decide which route to take
routing=routing,
sample_doc_str=sample_doc_str,
log_messages=[f"Path decision: {routing}, Time taken: {now_end - now_start}"],
log_messages=[
f"{now_start} -- Path decision: {routing}, Time taken: {now_end - now_start}"
],
)

View File

@ -8,7 +8,7 @@ from onyx.agents.agent_search.deep_search_a.main.states import MainState
def agent_path_routing(
state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
routing = state.get("routing", "agent_search")
routing = state.routing if hasattr(state, "routing") else "agent_search"
if routing == "agent_search":
agent_path = "agent_search_start"

View File

@ -48,14 +48,13 @@ def entity_term_extraction_llm(
# first four lines duplicates from generate_initial_answer
question = agent_a_config.search_request.query
sub_question_docs = state["documents"]
all_original_question_documents = state["all_original_question_documents"]
sub_question_docs = state.documents
all_original_question_documents = state.all_original_question_documents
relevant_docs = dedup_inference_sections(
sub_question_docs, all_original_question_documents
)
# start with the entity/term/extraction
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
@ -127,5 +126,8 @@ def entity_term_extraction_llm(
entities=entities,
relationships=relationships,
terms=terms,
)
),
log_messages=[
f"{now_start} -- Entity Term Extraction - Time taken: {now_end - now_start}"
],
)

View File

@ -64,8 +64,8 @@ def generate_initial_answer(
history = build_history_prompt(agent_a_config.message_history)
sub_question_docs = state["documents"]
all_original_question_documents = state["all_original_question_documents"]
sub_question_docs = state.documents
all_original_question_documents = state.all_original_question_documents
relevant_docs = dedup_inference_sections(
sub_question_docs, all_original_question_documents
@ -92,7 +92,7 @@ def generate_initial_answer(
else:
# Use the query info from the base document retrieval
query_info = get_query_info(state["original_question_retrieval_results"])
query_info = get_query_info(state.original_question_retrieval_results)
for tool_response in yield_search_responses(
query=question,
@ -117,7 +117,7 @@ def generate_initial_answer(
if all_original_question_doc not in sub_question_docs:
net_new_original_question_docs.append(all_original_question_doc)
decomp_answer_results = state["decomp_answer_results"]
decomp_answer_results = state.decomp_answer_results
good_qa_list: list[str] = []
@ -206,7 +206,7 @@ def generate_initial_answer(
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state["decomp_answer_results"], state["original_question_retrieval_stats"]
state.decomp_answer_results, state.original_question_retrieval_stats
)
logger.debug(
@ -228,12 +228,8 @@ def generate_initial_answer(
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state[
"original_question_retrieval_stats"
].verified_count,
verified_avg_score_core=state[
"original_question_retrieval_stats"
].verified_avg_scores,
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents", None
),
@ -246,7 +242,7 @@ def generate_initial_answer(
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio", None
),
duration__s=(agent_base_end_time - state["agent_start_time"]).total_seconds(),
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
)
return InitialAnswerUpdate(
@ -255,5 +251,7 @@ def generate_initial_answer(
generated_sub_questions=decomp_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[f"Initial answer generation: {now_end - now_start}"],
log_messages=[
f"{now_start} -- Initial answer generation - Time taken: {now_end - now_start}"
],
)

View File

@ -25,7 +25,7 @@ def generate_initial_base_search_only_answer(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
original_question_docs = state["all_original_question_documents"]
original_question_docs = state.all_original_question_documents
model = agent_a_config.fast_llm

View File

@ -61,12 +61,12 @@ def generate_refined_answer(
history = build_history_prompt(agent_a_config.message_history)
initial_documents = state["documents"]
revised_documents = state["refined_documents"]
initial_documents = state.documents
revised_documents = state.refined_documents
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
query_info = get_query_info(state["original_question_retrieval_results"])
query_info = get_query_info(state.original_question_retrieval_results)
# stream refined answer docs
for tool_response in yield_search_responses(
query=question,
@ -93,8 +93,8 @@ def generate_refined_answer(
else:
revision_doc_effectiveness = 10.0
decomp_answer_results = state["decomp_answer_results"]
# revised_answer_results = state["refined_decomp_answer_results"]
decomp_answer_results = state.decomp_answer_results
# revised_answer_results = state.refined_decomp_answer_results
good_qa_list: list[str] = []
decomp_questions = []
@ -147,7 +147,7 @@ def generate_refined_answer(
# original answer
initial_answer = state["initial_answer"]
initial_answer = state.initial_answer
# Determine which persona-specification prompt to use
@ -218,7 +218,7 @@ def generate_refined_answer(
answer = cast(str, response)
# refined_agent_stats = _calculate_refined_agent_stats(
# state["decomp_answer_results"], state["original_question_retrieval_stats"]
# state.decomp_answer_results, state.original_question_retrieval_stats
# )
initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
@ -252,22 +252,22 @@ def generate_refined_answer(
logger.debug("-" * 100)
if state["initial_agent_stats"]:
initial_doc_boost_factor = state["initial_agent_stats"].agent_effectiveness.get(
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = state[
"initial_agent_stats"
].agent_effectiveness.get("support_ratio", "--")
num_initial_verified_docs = state["initial_agent_stats"].original_question.get(
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = state[
"initial_agent_stats"
].original_question.get("verified_avg_score", "--")
initial_sub_questions_verified_docs = state[
"initial_agent_stats"
].sub_questions.get("num_verified_documents", "--")
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
@ -296,9 +296,9 @@ def generate_refined_answer(
)
agent_refined_end_time = datetime.now()
if state["agent_refined_start_time"]:
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state["agent_refined_start_time"]
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None

View File

@ -15,9 +15,9 @@ def ingest_initial_base_retrieval(
logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")
sub_question_retrieval_stats = state[
"base_expanded_retrieval_result"
].sub_question_retrieval_stats
sub_question_retrieval_stats = (
state.base_expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
else:
@ -30,11 +30,7 @@ def ingest_initial_base_retrieval(
)
return ExpandedRetrievalUpdate(
original_question_retrieval_results=state[
"base_expanded_retrieval_result"
].expanded_queries_results,
all_original_question_documents=state[
"base_expanded_retrieval_result"
].all_documents,
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
all_original_question_documents=state.base_expanded_retrieval_result.all_documents,
original_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@ -17,7 +17,7 @@ def ingest_initial_sub_question_answers(
logger.debug(f"--------{now_start}--------INGEST ANSWERS---")
documents = []
answer_results = state.get("answer_results", [])
answer_results = state.answer_results if hasattr(state, "answer_results") else []
for answer_result in answer_results:
documents.extend(answer_result.documents)

View File

@ -18,7 +18,7 @@ def ingest_refined_answers(
logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")
documents = []
answer_results = state.get("answer_results", [])
answer_results = state.answer_results if hasattr(state, "answer_results") else []
for answer_result in answer_results:
documents.extend(answer_result.documents)

View File

@ -53,7 +53,7 @@ def initial_sub_question_creation(
history = build_history_prompt(agent_a_config.message_history)
# Use the initial search results to inform the decomposition
sample_doc_str = state.get("sample_doc_str", "")
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
if not chat_session_id or not primary_message_id:
raise ValueError(

View File

@ -32,8 +32,8 @@ def refined_answer_decision(
f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---"
)
if not agent_a_config.allow_refinement:
if agent_a_config.allow_refinement:
return RequireRefinedAnswerUpdate(require_refined_answer=decision)
else:
return RequireRefinedAnswerUpdate(require_refined_answer=not decision)
return RequireRefinedAnswerUpdate(require_refined_answer=False)

View File

@ -37,7 +37,7 @@ def refined_sub_question_creation(
tool_name="agent_search_1",
tool_args={
"query": agent_a_config.search_request.query,
"answer": state["initial_answer"],
"answer": state.initial_answer,
},
),
)
@ -49,16 +49,16 @@ def refined_sub_question_creation(
agent_refined_start_time = datetime.now()
question = agent_a_config.search_request.query
base_answer = state["initial_answer"]
base_answer = state.initial_answer
history = build_history_prompt(agent_a_config.message_history)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state["entity_retlation_term_extractions"]
entity_retlation_term_extractions = state.entity_retlation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_retlation_term_extractions
)
initial_question_answers = state["decomp_answer_results"]
initial_question_answers = state.decomp_answer_results
addressed_question_list = [
x.question for x in initial_question_answers if "yes" in x.quality.lower()

View File

@ -0,0 +1,12 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
def retrieval_consolidation(
state: MainState,
) -> LoggerUpdate:
now_start = datetime.now()
return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])
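Note: retrieval_consolidation does nothing but log. Its value is structural: the list-valued edge added in the graph builder (start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"]) makes LangGraph wait for both ingestion nodes, and the consolidation node then fans out to entity_term_extraction_llm and generate_initial_answer in parallel. This is the "changed ERT extraction flow" from the commit title: entity/term extraction now runs alongside initial answer generation instead of after it.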

View File

@ -3,6 +3,8 @@ from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from pydantic import Field
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
@ -28,33 +30,36 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
)
from onyx.context.search.models import InferenceSection
### States ###
## Update States
class RefinedAgentStartStats(TypedDict):
agent_refined_start_time: datetime | None
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentEndStats(TypedDict):
agent_refined_end_time: datetime | None
agent_refined_metrics: AgentRefinedMetrics
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class BaseDecompUpdateBase(TypedDict):
agent_start_time: datetime
initial_decomp_questions: list[str]
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class RoutingDecisionBase(TypedDict):
routing: str
sample_doc_str: str
class BaseDecompUpdateBase(BaseModel):
agent_start_time: datetime = Field(default_factory=datetime.now) # per-instance timestamp, not import-time
initial_decomp_questions: list[str] = []
class RoutingDecision(RoutingDecisionBase):
log_messages: list[str]
class RoutingDecisionBase(BaseModel):
routing: str = ""
sample_doc_str: str = ""
class RoutingDecision(RoutingDecisionBase, LoggerUpdate):
pass
class BaseDecompUpdate(
@ -63,66 +68,72 @@ class BaseDecompUpdate(
pass
class InitialAnswerBASEUpdate(TypedDict):
initial_base_answer: str
class InitialAnswerBASEUpdate(BaseModel):
initial_base_answer: str = ""
class InitialAnswerUpdateBase(TypedDict):
initial_answer: str
initial_agent_stats: InitialAgentResultStats | None
generated_sub_questions: list[str]
agent_base_end_time: datetime
agent_base_metrics: AgentBaseMetrics | None
class InitialAnswerUpdateBase(BaseModel):
initial_answer: str = ""
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class InitialAnswerUpdate(InitialAnswerUpdateBase):
log_messages: list[str]
class InitialAnswerUpdate(InitialAnswerUpdateBase, LoggerUpdate):
pass
class RefinedAnswerUpdateBase(TypedDict):
refined_answer: str
refined_agent_stats: RefinedAgentStats | None
refined_answer_quality: bool
class RefinedAnswerUpdateBase(BaseModel):
refined_answer: str = ""
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase):
pass
class InitialAnswerQualityUpdate(TypedDict):
initial_answer_quality: bool
class InitialAnswerQualityUpdate(BaseModel):
initial_answer_quality: bool = False
class RequireRefinedAnswerUpdate(TypedDict):
require_refined_answer: bool
class RequireRefinedAnswerUpdate(BaseModel):
require_refined_answer: bool = True
class DecompAnswersUpdate(TypedDict):
documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DecompAnswersUpdate(BaseModel):
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
decomp_answer_results: Annotated[
list[QuestionAnswerResults], dedup_question_answer_results
]
] = []
class FollowUpDecompAnswersUpdate(TypedDict):
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections]
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add]
class FollowUpDecompAnswersUpdate(BaseModel):
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
class ExpandedRetrievalUpdate(TypedDict):
class ExpandedRetrievalUpdate(BaseModel):
all_original_question_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
original_question_retrieval_results: list[QueryResult]
original_question_retrieval_stats: AgentChunkStats
original_question_retrieval_results: list[QueryResult] = []
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
class EntityTermExtractionUpdate(TypedDict):
entity_retlation_term_extractions: EntityRelationshipTermExtraction
class EntityTermExtractionUpdateBase(BaseModel):
entity_retlation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class FollowUpSubQuestionsUpdateBase(TypedDict):
refined_sub_questions: dict[int, FollowUpSubQuestion]
class EntityTermExtractionUpdate(EntityTermExtractionUpdateBase, LoggerUpdate):
pass
class FollowUpSubQuestionsUpdateBase(BaseModel):
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
class FollowUpSubQuestionsUpdate(
@ -145,12 +156,13 @@ class MainInput(CoreState):
class MainState(
# This includes the core state
MainInput,
LoggerUpdate,
BaseDecompUpdateBase,
InitialAnswerUpdateBase,
InitialAnswerBASEUpdate,
DecompAnswersUpdate,
ExpandedRetrievalUpdate,
EntityTermExtractionUpdate,
EntityTermExtractionUpdateBase,
InitialAnswerQualityUpdate,
RequireRefinedAnswerUpdate,
FollowUpSubQuestionsUpdateBase,

View File

@ -22,7 +22,7 @@ class AgentSearchConfig:
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool
use_agentic_search: bool = False
use_agentic_search: bool = True
# For persisting agent search data
chat_session_id: UUID | None = None
@ -37,13 +37,13 @@ class AgentSearchConfig:
db_session: Session | None = None
# Whether to perform initial search to inform decomposition
perform_initial_search_path_decision: bool = False
perform_initial_search_path_decision: bool = True
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = False
perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = False
allow_refinement: bool = True
# Message history for the current chat session
message_history: list[PreviousMessage] | None = None

View File

@ -147,6 +147,7 @@ def run_graph(
# TODO: add these to the environment
config.perform_initial_search_path_decision = True
config.perform_initial_search_decomposition = True
config.allow_refinement = True
for event in _manage_async_event_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
@ -176,9 +177,9 @@ def run_main_graph(
) -> AnswerStream:
compiled_graph = load_compiled_graph(graph_name)
if graph_name == "a":
input = MainInput_a()
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
else:
input = MainInput_a()
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@ -238,9 +239,13 @@ if __name__ == "__main__":
config.perform_initial_search_path_decision = True
config.perform_initial_search_decomposition = True
if GRAPH_NAME == "a":
input = MainInput_a()
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
)
else:
input = MainInput_a()
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
)
# with open("output.txt", "w") as f:
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):

View File

@ -40,12 +40,12 @@ class AgentChunkScores(BaseModel):
class AgentChunkStats(BaseModel):
verified_count: int | None
verified_avg_scores: float | None
rejected_count: int | None
rejected_avg_scores: float | None
verified_doc_chunk_ids: list[str]
dismissed_doc_chunk_ids: list[str]
verified_count: int | None = None
verified_avg_scores: float | None = None
rejected_count: int | None = None
rejected_avg_scores: float | None = None
verified_doc_chunk_ids: list[str] = []
dismissed_doc_chunk_ids: list[str] = []
class InitialAgentResultStats(BaseModel):
@ -60,29 +60,29 @@ class RefinedAgentStats(BaseModel):
class Term(BaseModel):
term_name: str
term_type: str
term_similar_to: list[str]
term_name: str = ""
term_type: str = ""
term_similar_to: list[str] = []
### Models ###
class Entity(BaseModel):
entity_name: str
entity_type: str
entity_name: str = ""
entity_type: str = ""
class Relationship(BaseModel):
relationship_name: str
relationship_type: str
relationship_entities: list[str]
relationship_name: str = ""
relationship_type: str = ""
relationship_entities: list[str] = []
class EntityRelationshipTermExtraction(BaseModel):
entities: list[Entity]
relationships: list[Relationship]
terms: list[Term]
entities: list[Entity] = []
relationships: list[Relationship] = []
terms: list[Term] = []
### Models ###
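One detail worth noting about the `= []` and `= Model()` defaults introduced throughout: pydantic copies field defaults for each instance, so mutable defaults are safe here (unlike bare class attributes or dataclass fields). The exception is a computed default such as datetime.now(), which is evaluated once at class definition and therefore needs Field(default_factory=...), as in BaseDecompUpdateBase above. A quick illustration (not from the diff):

from pydantic import BaseModel

class Box(BaseModel):
    items: list[str] = []

a, b = Box(), Box()
a.items.append("x")
assert b.items == []  # each instance gets its own copy of the default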