mirror of https://github.com/danswer-ai/danswer.git
synced 2025-03-29 11:12:02 +01:00

pydantic for LangGraph + changed ERT extraction flow

commit b7f9e431a5 (parent b9bd2ea4e2)
@@ -1,18 +1,19 @@
 from operator import add
 from typing import Annotated
-from typing import TypedDict
+
+from pydantic import BaseModel
 
 
-class CoreState(TypedDict, total=False):
+class CoreState(BaseModel):
     """
     This is the core state that is shared across all subgraphs.
     """
 
-    base_question: str
-    log_messages: Annotated[list[str], add]
+    base_question: str = ""
+    log_messages: Annotated[list[str], add] = []
 
 
-class SubgraphCoreState(TypedDict, total=False):
+class SubgraphCoreState(BaseModel):
     """
     This is the core state that is shared across all subgraphs.
     """
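The hunk above is the template for the whole commit: every LangGraph state class moves from TypedDict to a pydantic BaseModel, and because pydantic fields without defaults are required at construction time, every field picks up an explicit default. A minimal sketch of the pattern (illustrative names, not repo code):

from operator import add
from typing import Annotated

from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel


class CoreState(BaseModel):
    base_question: str = ""
    # `add` is the reducer LangGraph uses to merge concurrent writes:
    # lists returned by parallel nodes get concatenated.
    log_messages: Annotated[list[str], add] = []


def answer(state: CoreState) -> dict:
    # attribute access replaces state["base_question"]
    return {"log_messages": [f"answering: {state.base_question}"]}


graph = StateGraph(CoreState)
graph.add_node("answer", answer)
graph.add_edge(START, "answer")
graph.add_edge("answer", END)
print(graph.compile().invoke(CoreState(base_question="hi")))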
@@ -1,4 +1,5 @@
 from collections.abc import Hashable
+from datetime import datetime
 
 from langgraph.types import Send
 
@@ -15,12 +16,14 @@ logger = setup_logger()
 
 def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
     logger.debug("sending to expanded retrieval via edge")
+    now_start = datetime.now()
 
     return Send(
         "initial_sub_question_expanded_retrieval",
         ExpandedRetrievalInput(
-            question=state["question"],
+            question=state.question,
             base_search=False,
-            sub_question_id=state["question_id"],
+            sub_question_id=state.question_id,
+            log_messages=[f"{now_start} -- Sending to expanded retrieval"],
        ),
    )
@@ -115,6 +115,7 @@ if __name__ == "__main__":
     inputs = AnswerQuestionInput(
         question="what can you do with onyx?",
         question_id="0_0",
+        log_messages=[],
     )
     for thing in compiled_graph.stream(
         input=inputs,
@@ -17,15 +17,15 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 
 
 def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
-    if state["answer"] == UNKNOWN_ANSWER:
+    if state.answer == UNKNOWN_ANSWER:
         return QACheckUpdate(
             answer_quality=SUB_CHECK_NO,
         )
     msg = [
         HumanMessage(
             content=SUB_CHECK_PROMPT.format(
-                question=state["question"],
-                base_answer=state["answer"],
+                question=state.question,
+                base_answer=state.answer,
             )
         )
     ]
@@ -40,9 +40,9 @@ def answer_generation(
     logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
 
     agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = state["question"]
-    docs = state["documents"]
-    level, question_nr = parse_question_id(state["question_id"])
+    question = state.question
+    docs = state.documents
+    level, question_nr = parse_question_id(state.question_id)
     persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
 
     if len(docs) == 0:
@@ -13,13 +13,15 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
     return AnswerQuestionOutput(
         answer_results=[
             QuestionAnswerResults(
-                question=state["question"],
-                question_id=state["question_id"],
-                quality=state.get("answer_quality", "No"),
-                answer=state["answer"],
-                expanded_retrieval_results=state["expanded_retrieval_results"],
-                documents=state["documents"],
-                sub_question_retrieval_stats=state["sub_question_retrieval_stats"],
+                question=state.question,
+                question_id=state.question_id,
+                quality=state.answer_quality
+                if hasattr(state, "answer_quality")
+                else "No",
+                answer=state.answer,
+                expanded_retrieval_results=state.expanded_retrieval_results,
+                documents=state.documents,
+                sub_question_retrieval_stats=state.sub_question_retrieval_stats,
             )
         ],
     )
@@ -8,16 +8,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 
 
 def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
-    sub_question_retrieval_stats = state[
-        "expanded_retrieval_result"
-    ].sub_question_retrieval_stats
+    sub_question_retrieval_stats = (
+        state.expanded_retrieval_result.sub_question_retrieval_stats
+    )
     if sub_question_retrieval_stats is None:
         sub_question_retrieval_stats = [AgentChunkStats()]
 
     return RetrievalIngestionUpdate(
-        expanded_retrieval_results=state[
-            "expanded_retrieval_result"
-        ].expanded_queries_results,
-        documents=state["expanded_retrieval_result"].all_documents,
+        expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
+        documents=state.expanded_retrieval_result.all_documents,
         sub_question_retrieval_stats=sub_question_retrieval_stats,
     )
@@ -1,6 +1,7 @@
 from operator import add
 from typing import Annotated
-from typing import TypedDict
+
+from pydantic import BaseModel
 
 from onyx.agents.agent_search.core_state import SubgraphCoreState
 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
@@ -15,27 +16,29 @@ from onyx.context.search.models import InferenceSection
 
 
 ## Update States
-class QACheckUpdate(TypedDict):
-    answer_quality: str
+class QACheckUpdate(BaseModel):
+    answer_quality: str = ""
 
 
-class QAGenerationUpdate(TypedDict):
-    answer: str
+class QAGenerationUpdate(BaseModel):
+    answer: str = ""
     # answer_stat: AnswerStats
 
 
-class RetrievalIngestionUpdate(TypedDict):
-    expanded_retrieval_results: list[QueryResult]
-    documents: Annotated[list[InferenceSection], dedup_inference_sections]
-    sub_question_retrieval_stats: AgentChunkStats
+class RetrievalIngestionUpdate(BaseModel):
+    expanded_retrieval_results: list[QueryResult] = []
+    documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
+    sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
 
 
 ## Graph Input State
 
 
 class AnswerQuestionInput(SubgraphCoreState):
-    question: str
-    question_id: str  # 0_0 is original question, everything else is <level>_<question_num>.
+    question: str = ""
+    question_id: str = (
+        ""  # 0_0 is original question, everything else is <level>_<question_num>.
+    )
     # level 0 is original question and first decomposition, level 1 is follow up, etc
     # question_num is a unique number per original question per level.
@@ -55,11 +58,11 @@ class AnswerQuestionState(
 ## Graph Output State
 
 
-class AnswerQuestionOutput(TypedDict):
+class AnswerQuestionOutput(BaseModel):
     """
     This is a list of results even though each call of this subgraph only returns one result.
     This is because if we parallelize the answer query subgraph, there will be multiple
     results in a list so the add operator is used to add them together.
     """
 
-    answer_results: Annotated[list[QuestionAnswerResults], add]
+    answer_results: Annotated[list[QuestionAnswerResults], add] = []
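An aside on the `= []` defaults appearing throughout: pydantic deep-copies mutable defaults per instance, so unlike a plain class attribute (or a bare dataclass default) the empty list is not shared between states. A quick check, on an illustrative model:

from pydantic import BaseModel


class Out(BaseModel):
    answer_results: list[str] = []


a, b = Out(), Out()
a.answer_results.append("x")
assert b.answer_results == []  # each instance got its own copy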
@@ -1,4 +1,5 @@
 from collections.abc import Hashable
+from datetime import datetime
 
 from langgraph.types import Send
 
@@ -15,12 +16,13 @@ logger = setup_logger()
 
 def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
     logger.debug("sending to expanded retrieval for follow up question via edge")
 
+    datetime.now()
     return Send(
         "refined_sub_question_expanded_retrieval",
         ExpandedRetrievalInput(
-            question=state["question"],
-            sub_question_id=state["question_id"],
+            question=state.question,
+            sub_question_id=state.question_id,
             base_search=False,
+            log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
        ),
    )
@@ -111,6 +111,7 @@ if __name__ == "__main__":
     inputs = AnswerQuestionInput(
         question="what can you do with onyx?",
         question_id="0_0",
+        log_messages=[],
     )
     for thing in compiled_graph.stream(
         input=inputs,
@@ -12,7 +12,7 @@ logger = setup_logger()
 def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
     logger.debug("format_raw_search_results")
     return BaseRawSearchOutput(
-        base_expanded_retrieval_result=state["expanded_retrieval_result"],
-        # base_retrieval_results=[state["expanded_retrieval_result"]],
+        base_expanded_retrieval_result=state.expanded_retrieval_result,
+        # base_retrieval_results=[state.expanded_retrieval_result],
         # base_search_documents=[],
     )
@@ -21,4 +21,5 @@ def generate_raw_search_data(
         question=agent_a_config.search_request.query,
         base_search=True,
         sub_question_id=None,  # This graph is always and only used for the original question
+        log_messages=[],
     )
@@ -1,4 +1,4 @@
-from typing import TypedDict
+from pydantic import BaseModel
 
 from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
     ExpandedRetrievalResult,
@@ -21,7 +21,7 @@ class BaseRawSearchInput(ExpandedRetrievalInput):
 ## Graph Output State
 
 
-class BaseRawSearchOutput(TypedDict):
+class BaseRawSearchOutput(BaseModel):
     """
     This is a list of results even though each call of this subgraph only returns one result.
     This is because if we parallelize the answer query subgraph, there will be multiple
@@ -30,7 +30,7 @@ class BaseRawSearchOutput(TypedDict):
 
     # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
     # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
-    base_expanded_retrieval_result: ExpandedRetrievalResult
+    base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
 
 
 ## Graph State
@@ -17,9 +17,11 @@ def parallel_retrieval_edge(
     state: ExpandedRetrievalState, config: RunnableConfig
 ) -> list[Send | Hashable]:
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = state.get("question", agent_a_config.search_request.query)
+    question = state.question if state.question else agent_a_config.search_request.query
 
-    query_expansions = state.get("expanded_queries", []) + [question]
+    query_expansions = (
+        state.expanded_queries if state.expanded_queries else [] + [question]
+    )
     return [
         Send(
             "doc_retrieval",
@@ -27,7 +29,8 @@ def parallel_retrieval_edge(
                 query_to_retrieve=query,
                 question=question,
                 base_search=False,
-                sub_question_id=state.get("sub_question_id"),
+                sub_question_id=state.sub_question_id,
+                log_messages=[],
             ),
         )
         for query in query_expansions
@@ -14,6 +14,9 @@ from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retriev
 from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
     doc_verification,
 )
+from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
+    dummy,
+)
 from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
     expand_queries,
 )
@@ -52,6 +55,11 @@ def expanded_retrieval_graph_builder() -> StateGraph:
         action=expand_queries,
     )
 
+    graph.add_node(
+        node="dummy",
+        action=dummy,
+    )
+
     graph.add_node(
         node="doc_retrieval",
         action=doc_retrieval,
@@ -78,9 +86,13 @@ def expanded_retrieval_graph_builder() -> StateGraph:
         start_key=START,
         end_key="expand_queries",
     )
+    graph.add_edge(
+        start_key="expand_queries",
+        end_key="dummy",
+    )
 
     graph.add_conditional_edges(
-        source="expand_queries",
+        source="dummy",
         path=parallel_retrieval_edge,
         path_map=["doc_retrieval"],
     )
@@ -124,6 +136,7 @@ if __name__ == "__main__":
         question="what can you do with onyx?",
         base_search=False,
         sub_question_id=None,
+        log_messages=[],
     )
     for thing in compiled_graph.stream(
         input=inputs,
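For context on the conditional edge the new `dummy` node now feeds: a path function that returns `Send` objects fans out one invocation of the target node per `Send`, each with its own input payload. A minimal sketch of that mechanism (illustrative graph, not the onyx one):

from operator import add
from typing import Annotated

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from pydantic import BaseModel


class State(BaseModel):
    queries: list[str] = []
    hits: Annotated[list[str], add] = []  # reducer merges the parallel writes


def fan_out(state: State) -> list[Send]:
    # one retrieval branch per expanded query
    return [Send("retrieve", State(queries=[q])) for q in state.queries]


def retrieve(state: State) -> dict:
    return {"hits": [f"hit-for-{state.queries[0]}"]}


g = StateGraph(State)
g.add_node("retrieve", retrieve)
g.add_conditional_edges(START, fan_out, ["retrieve"])
g.add_edge("retrieve", END)
print(g.compile().invoke(State(queries=["q1", "q2"])))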
@@ -6,6 +6,6 @@ from onyx.context.search.models import InferenceSection
 
 
 class ExpandedRetrievalResult(BaseModel):
-    expanded_queries_results: list[QueryResult]
-    all_documents: list[InferenceSection]
-    sub_question_retrieval_stats: AgentChunkStats
+    expanded_queries_results: list[QueryResult] = []
+    all_documents: list[InferenceSection] = []
+    sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
@@ -24,13 +24,13 @@ from onyx.db.engine import get_session_context_manager
 def doc_reranking(
     state: ExpandedRetrievalState, config: RunnableConfig
 ) -> DocRerankingUpdate:
-    verified_documents = state["verified_documents"]
+    verified_documents = state.verified_documents
 
     # Rerank post retrieval and verification. First, create a search query
     # then create the list of reranked sections
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = state.get("question", agent_a_config.search_request.query)
+    question = state.question if state.question else agent_a_config.search_request.query
     with get_session_context_manager() as db_session:
         _search_query = retrieval_preprocessing(
             search_request=SearchRequest(query=question),
@@ -35,7 +35,7 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval
         expanded_retrieval_results: list[ExpandedRetrievalResult]
         retrieved_documents: list[InferenceSection]
     """
-    query_to_retrieve = state["query_to_retrieve"]
+    query_to_retrieve = state.query_to_retrieve
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     search_tool = agent_a_config.search_tool
 
@@ -30,8 +30,8 @@ def doc_verification(
         verified_documents: list[InferenceSection]
     """
 
-    question = state["question"]
-    doc_to_verify = state["doc_to_verify"]
+    question = state.question
+    doc_to_verify = state.doc_to_verify
     document_content = doc_to_verify.combined_content
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
@@ -0,0 +1,16 @@
+from langchain_core.runnables.config import RunnableConfig
+
+from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
+    ExpandedRetrievalState,
+)
+from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
+    QueryExpansionUpdate,
+)
+
+
+def dummy(
+    state: ExpandedRetrievalState, config: RunnableConfig
+) -> QueryExpansionUpdate:
+    return QueryExpansionUpdate(
+        expanded_queries=state.expanded_queries,
+    )
@@ -28,10 +28,14 @@ def expand_queries(
     # When we are running this node on the original question, no question is explictly passed in.
     # Instead, we use the original question from the search request.
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = state.get("question", agent_a_config.search_request.query)
+    question = (
+        state.question
+        if hasattr(state, "question")
+        else agent_a_config.search_request.query
+    )
     llm = agent_a_config.fast_llm
     chat_session_id = agent_a_config.chat_session_id
-    sub_question_id = state.get("sub_question_id")
+    sub_question_id = state.sub_question_id
     if sub_question_id is None:
         level, question_nr = 0, 0
     else:
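A note on the `hasattr` fallbacks that replace `state.get(...)` here and elsewhere: once a pydantic field declares a default, the attribute always exists, so `hasattr(state, "question")` is True for any state class that includes the field, and the `else` branch only covers classes that genuinely lack it. A quick demonstration on an illustrative model:

from pydantic import BaseModel


class S(BaseModel):
    question: str = ""


s = S()
assert hasattr(s, "question")  # True even though nothing was set
assert (s.question if hasattr(s, "question") else "fallback") == ""
# the emptiness check used in parallel_retrieval_edge behaves differently:
assert (s.question if s.question else "fallback") == "fallback"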
@@ -25,10 +25,10 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp
 def format_results(
     state: ExpandedRetrievalState, config: RunnableConfig
 ) -> ExpandedRetrievalUpdate:
-    level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0")
+    level, question_nr = parse_question_id(state.sub_question_id or "0_0")
     query_infos = [
         result.query_info
-        for result in state["expanded_retrieval_results"]
+        for result in state.expanded_retrieval_results
         if result.query_info is not None
     ]
     if len(query_infos) == 0:
@@ -37,19 +37,15 @@ def format_results(
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     # main question docs will be sent later after aggregation and deduping with sub-question docs
     if not (level == 0 and question_nr == 0):
-        if len(state["reranked_documents"]) > 0:
-            stream_documents = state["reranked_documents"]
+        if len(state.reranked_documents) > 0:
+            stream_documents = state.reranked_documents
         else:
             # The sub-question is used as the last query. If no verified documents are found, stream
             # the top 3 for that one. We may want to revisit this.
-            stream_documents = state["expanded_retrieval_results"][-1].search_results[
-                :3
-            ]
+            stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
         for tool_response in yield_search_responses(
-            query=state["question"],
-            reranked_sections=state[
-                "retrieved_documents"
-            ],  # TODO: rename params. this one is supposed to be the sections pre-merging
+            query=state.question,
+            reranked_sections=state.retrieved_documents,  # TODO: rename params. (sections pre-merging here.)
             final_context_sections=stream_documents,
             search_query_info=query_infos[0],  # TODO: handle differing query infos?
             get_section_relevance=lambda: None,  # TODO: add relevance
@@ -65,8 +61,8 @@ def format_results(
         ),
     )
     sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
-        verified_documents=state["verified_documents"],
-        expanded_retrieval_results=state["expanded_retrieval_results"],
+        verified_documents=state.verified_documents,
+        expanded_retrieval_results=state.expanded_retrieval_results,
     )
 
     if sub_question_retrieval_stats is None:
@@ -76,8 +72,8 @@ def format_results(
 
     return ExpandedRetrievalUpdate(
         expanded_retrieval_result=ExpandedRetrievalResult(
-            expanded_queries_results=state["expanded_retrieval_results"],
-            all_documents=state["reranked_documents"],
+            expanded_queries_results=state.expanded_retrieval_results,
+            all_documents=state.reranked_documents,
             sub_question_retrieval_stats=sub_question_retrieval_stats,
         ),
    )
@@ -18,10 +18,14 @@ def verification_kickoff(
     state: ExpandedRetrievalState,
     config: RunnableConfig,
 ) -> Command[Literal["doc_verification"]]:
-    documents = state["retrieved_documents"]
+    documents = state.retrieved_documents
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    verification_question = state.get("question", agent_a_config.search_request.query)
-    sub_question_id = state.get("sub_question_id")
+    verification_question = (
+        state.question
+        if hasattr(state, "question")
+        else agent_a_config.search_request.query
+    )
+    sub_question_id = state.sub_question_id
     return Command(
         update={},
         goto=[
@@ -32,6 +36,7 @@ def verification_kickoff(
                     question=verification_question,
                     base_search=False,
                     sub_question_id=sub_question_id,
+                    log_messages=[],
                 ),
             )
             for doc in documents
@@ -1,6 +1,7 @@
 from operator import add
 from typing import Annotated
-from typing import TypedDict
+
+from pydantic import BaseModel
 
 from onyx.agents.agent_search.core_state import SubgraphCoreState
 from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
@@ -20,42 +21,44 @@ from onyx.context.search.models import InferenceSection
 
 
 class ExpandedRetrievalInput(SubgraphCoreState):
-    question: str
-    base_search: bool
-    sub_question_id: str | None
+    question: str = ""
+    base_search: bool = False
+    sub_question_id: str | None = None
 
 
 ## Update/Return States
 
 
-class QueryExpansionUpdate(TypedDict):
-    expanded_queries: list[str]
+class QueryExpansionUpdate(BaseModel):
+    expanded_queries: list[str] = ["aaa", "bbb"]
 
 
-class DocVerificationUpdate(TypedDict):
-    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+class DocVerificationUpdate(BaseModel):
+    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
 
 
-class DocRetrievalUpdate(TypedDict):
-    expanded_retrieval_results: Annotated[list[QueryResult], add]
-    retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+class DocRetrievalUpdate(BaseModel):
+    expanded_retrieval_results: Annotated[list[QueryResult], add] = []
+    retrieved_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ] = []
 
 
-class DocRerankingUpdate(TypedDict):
-    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
-    sub_question_retrieval_stats: RetrievalFitStats | None
+class DocRerankingUpdate(BaseModel):
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
+    sub_question_retrieval_stats: RetrievalFitStats | None = None
 
 
-class ExpandedRetrievalUpdate(TypedDict):
+class ExpandedRetrievalUpdate(BaseModel):
     expanded_retrieval_result: ExpandedRetrievalResult
 
 
 ## Graph Output State
 
 
-class ExpandedRetrievalOutput(TypedDict):
-    expanded_retrieval_result: ExpandedRetrievalResult
-    base_expanded_retrieval_result: ExpandedRetrievalResult
+class ExpandedRetrievalOutput(BaseModel):
+    expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
+    base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
 
 
 ## Graph State
@@ -81,4 +84,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
 
 
 class RetrievalInput(ExpandedRetrievalInput):
-    query_to_retrieve: str
+    query_to_retrieve: str = ""
@@ -22,7 +22,7 @@ logger = setup_logger()
 def parallelize_initial_sub_question_answering(
     state: MainState,
 ) -> list[Send | Hashable]:
-    if len(state["initial_decomp_questions"]) > 0:
+    if len(state.initial_decomp_questions) > 0:
         # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
         # if len(state["sub_question_records"]) == 0:
         #     if state["config"].use_persistence:
@@ -39,9 +39,10 @@ def parallelize_initial_sub_question_answering(
                 AnswerQuestionInput(
                     question=question,
                     question_id=make_question_id(0, question_nr + 1),
+                    log_messages=[],
                 ),
             )
-            for question_nr, question in enumerate(state["initial_decomp_questions"])
+            for question_nr, question in enumerate(state.initial_decomp_questions)
         ]
 
     else:
@@ -59,7 +60,7 @@ def parallelize_initial_sub_question_answering(
 def continue_to_refined_answer_or_end(
     state: RequireRefinedAnswerUpdate,
 ) -> Literal["refined_sub_question_creation", "logging_node"]:
-    if state["require_refined_answer"]:
+    if state.require_refined_answer:
         return "refined_sub_question_creation"
     else:
         return "logging_node"
@@ -68,16 +69,17 @@ def continue_to_refined_answer_or_end(
 def parallelize_refined_sub_question_answering(
     state: MainState,
 ) -> list[Send | Hashable]:
-    if len(state["refined_sub_questions"]) > 0:
+    if len(state.refined_sub_questions) > 0:
         return [
             Send(
                 "answer_refined_question",
                 AnswerQuestionInput(
                     question=question_data.sub_question,
                     question_id=make_question_id(1, question_nr),
+                    log_messages=[],
                 ),
             )
-            for question_nr, question_data in state["refined_sub_questions"].items()
+            for question_nr, question_data in state.refined_sub_questions.items()
         ]
 
     else:
@@ -65,6 +65,9 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision i
 from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
     refined_sub_question_creation,
 )
+from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
+    retrieval_consolidation,
+)
 from onyx.agents.agent_search.deep_search_a.main.states import MainInput
 from onyx.agents.agent_search.deep_search_a.main.states import MainState
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -153,6 +156,12 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         node="ingest_initial_retrieval",
         action=ingest_initial_base_retrieval,
     )
+
+    graph.add_node(
+        node="retrieval_consolidation",
+        action=retrieval_consolidation,
+    )
+
     graph.add_node(
         node="ingest_initial_sub_question_answers",
         action=ingest_initial_sub_question_answers,
@@ -215,6 +224,21 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
         end_key="ingest_initial_retrieval",
     )
 
+    graph.add_edge(
+        start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
+        end_key="retrieval_consolidation",
+    )
+
+    graph.add_edge(
+        start_key="retrieval_consolidation",
+        end_key="entity_term_extraction_llm",
+    )
+
+    graph.add_edge(
+        start_key="retrieval_consolidation",
+        end_key="generate_initial_answer",
+    )
+
     graph.add_edge(
         start_key="LLM",
         end_key=END,
@@ -236,14 +260,14 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     )
 
     graph.add_edge(
-        start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"],
+        start_key="retrieval_consolidation",
         end_key="generate_initial_answer",
     )
 
-    graph.add_edge(
-        start_key="generate_initial_answer",
-        end_key="entity_term_extraction_llm",
-    )
+    # graph.add_edge(
+    #     start_key="generate_initial_answer",
+    #     end_key="entity_term_extraction_llm",
+    # )
 
     graph.add_edge(
         start_key="generate_initial_answer",
@@ -327,7 +351,9 @@ if __name__ == "__main__":
         db_session, primary_llm, fast_llm, search_request
     )
 
-    inputs = MainInput()
+    inputs = MainInput(
+        base_question=agent_a_config.search_request.query, log_messages=[]
+    )
 
     for thing in compiled_graph.stream(
         input=inputs,
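The new `retrieval_consolidation` wiring above uses `add_edge` with a list of start keys, which in LangGraph makes the target a join: it runs once, after all listed nodes have finished, rather than once per incoming edge. This is part of the changed ERT extraction flow, since `entity_term_extraction_llm` now hangs off the consolidation point instead of the initial answer. A sketch of the join behavior (illustrative graph):

from operator import add
from typing import Annotated

from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel


class S(BaseModel):
    log: Annotated[list[str], add] = []


def ingest_answers(state: S) -> dict:
    return {"log": ["answers ingested"]}


def ingest_retrieval(state: S) -> dict:
    return {"log": ["retrieval ingested"]}


def consolidate(state: S) -> dict:
    return {"log": ["consolidated"]}


g = StateGraph(S)
g.add_node("ingest_answers", ingest_answers)
g.add_node("ingest_retrieval", ingest_retrieval)
g.add_node("consolidate", consolidate)
g.add_edge(START, "ingest_answers")
g.add_edge(START, "ingest_retrieval")
g.add_edge(["ingest_answers", "ingest_retrieval"], "consolidate")  # waits for both
g.add_edge("consolidate", END)
print(g.compile().invoke(S()))  # "consolidated" is logged once, last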
@@ -20,16 +20,16 @@ class AgentBaseMetrics(BaseModel):
     num_verified_documents_core: int | None
     verified_avg_score_core: float | None
     num_verified_documents_base: int | float | None
-    verified_avg_score_base: float | None
-    base_doc_boost_factor: float | None
-    support_boost_factor: float | None
-    duration__s: float | None
+    verified_avg_score_base: float | None = None
+    base_doc_boost_factor: float | None = None
+    support_boost_factor: float | None = None
+    duration__s: float | None = None
 
 
 class AgentRefinedMetrics(BaseModel):
-    refined_doc_boost_factor: float | None
-    refined_question_boost_factor: float | None
-    duration__s: float | None
+    refined_doc_boost_factor: float | None = None
+    refined_question_boost_factor: float | None = None
+    duration__s: float | None = None
 
 
 class AgentAdditionalMetrics(BaseModel):
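Worth spelling out why the `= None` additions above matter: in pydantic v2 a `float | None` annotation makes a field nullable but still required, so without explicit defaults these metrics models could not be constructed incrementally. For example (illustrative model):

from pydantic import BaseModel, ValidationError


class Metrics(BaseModel):
    required_nullable: float | None       # must be passed, may be None
    optional_nullable: float | None = None


try:
    Metrics()
except ValidationError:
    print("required_nullable missing")    # raised: the field is required

Metrics(required_nullable=None)           # fine: None satisfies the type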
@@ -19,10 +19,10 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
 
     logger.debug(f"--------{now_start}--------LOGGING NODE---")
 
-    agent_start_time = state["agent_start_time"]
-    agent_base_end_time = state["agent_base_end_time"]
-    agent_refined_start_time = state["agent_refined_start_time"] or None
-    agent_refined_end_time = state["agent_refined_end_time"] or None
+    agent_start_time = state.agent_start_time
+    agent_base_end_time = state.agent_base_end_time
+    agent_refined_start_time = state.agent_refined_start_time or None
+    agent_refined_end_time = state.agent_refined_end_time or None
     agent_end_time = agent_refined_end_time or agent_base_end_time
 
     agent_base_duration = None
@@ -41,8 +41,8 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
 
     agent_type = "refined" if agent_refined_duration else "base"
 
-    agent_base_metrics = state["agent_base_metrics"]
-    agent_refined_metrics = state["agent_refined_metrics"]
+    agent_base_metrics = state.agent_base_metrics
+    agent_refined_metrics = state.agent_refined_metrics
 
     combined_agent_metrics = CombinedAgentMetrics(
         timings=AgentTimings(
@@ -81,7 +81,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
     db_session = agent_a_config.db_session
     chat_session_id = agent_a_config.chat_session_id
     primary_message_id = agent_a_config.message_id
-    sub_question_answer_results = state["decomp_answer_results"]
+    sub_question_answer_results = state.decomp_answer_results
 
     log_agent_sub_question_results(
         db_session=db_session,
@@ -17,6 +17,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import (
 )
 from onyx.context.search.models import InferenceSection
 from onyx.db.engine import get_session_context_manager
+from onyx.llm.utils import check_number_of_tokens
 from onyx.tools.tool_implementations.search.search_tool import (
     SEARCH_RESPONSE_SUMMARY_ID,
 )
@@ -86,9 +87,13 @@ def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDeci
         f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
     )
 
+    check_number_of_tokens(agent_decision_prompt)
+
     return RoutingDecision(
         # Decide which route to take
         routing=routing,
         sample_doc_str=sample_doc_str,
-        log_messages=[f"Path decision: {routing}, Time taken: {now_end - now_start}"],
+        log_messages=[
+            f"{now_start} -- Path decision: {routing}, Time taken: {now_end - now_start}"
+        ],
    )
@@ -8,7 +8,7 @@ from onyx.agents.agent_search.deep_search_a.main.states import MainState
 def agent_path_routing(
     state: MainState,
 ) -> Command[Literal["agent_search_start", "LLM"]]:
-    routing = state.get("routing", "agent_search")
+    routing = state.routing if hasattr(state, "routing") else "agent_search"
 
     if routing == "agent_search":
         agent_path = "agent_search_start"
@@ -48,14 +48,13 @@ def entity_term_extraction_llm(
 
     # first four lines duplicates from generate_initial_answer
     question = agent_a_config.search_request.query
-    sub_question_docs = state["documents"]
-    all_original_question_documents = state["all_original_question_documents"]
+    sub_question_docs = state.documents
+    all_original_question_documents = state.all_original_question_documents
     relevant_docs = dedup_inference_sections(
         sub_question_docs, all_original_question_documents
     )
 
     # start with the entity/term/extraction
-
     doc_context = format_docs(relevant_docs)
 
     doc_context = trim_prompt_piece(
@@ -127,5 +126,8 @@ def entity_term_extraction_llm(
             entities=entities,
             relationships=relationships,
             terms=terms,
-        )
+        ),
+        log_messages=[
+            f"{now_start} -- Entity Term Extraction - Time taken: {now_end - now_start}"
+        ],
     )
@@ -64,8 +64,8 @@ def generate_initial_answer(
 
     history = build_history_prompt(agent_a_config.message_history)
 
-    sub_question_docs = state["documents"]
-    all_original_question_documents = state["all_original_question_documents"]
+    sub_question_docs = state.documents
+    all_original_question_documents = state.all_original_question_documents
 
     relevant_docs = dedup_inference_sections(
         sub_question_docs, all_original_question_documents
@@ -92,7 +92,7 @@ def generate_initial_answer(
 
     else:
         # Use the query info from the base document retrieval
-        query_info = get_query_info(state["original_question_retrieval_results"])
+        query_info = get_query_info(state.original_question_retrieval_results)
 
         for tool_response in yield_search_responses(
             query=question,
@@ -117,7 +117,7 @@ def generate_initial_answer(
         if all_original_question_doc not in sub_question_docs:
             net_new_original_question_docs.append(all_original_question_doc)
 
-    decomp_answer_results = state["decomp_answer_results"]
+    decomp_answer_results = state.decomp_answer_results
 
     good_qa_list: list[str] = []
 
@@ -206,7 +206,7 @@ def generate_initial_answer(
     answer = cast(str, response)
 
     initial_agent_stats = calculate_initial_agent_stats(
-        state["decomp_answer_results"], state["original_question_retrieval_stats"]
+        state.decomp_answer_results, state.original_question_retrieval_stats
     )
 
     logger.debug(
@@ -228,12 +228,8 @@ def generate_initial_answer(
 
     agent_base_metrics = AgentBaseMetrics(
         num_verified_documents_total=len(relevant_docs),
-        num_verified_documents_core=state[
-            "original_question_retrieval_stats"
-        ].verified_count,
-        verified_avg_score_core=state[
-            "original_question_retrieval_stats"
-        ].verified_avg_scores,
+        num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
+        verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
         num_verified_documents_base=initial_agent_stats.sub_questions.get(
             "num_verified_documents", None
         ),
@@ -246,7 +242,7 @@ def generate_initial_answer(
         support_boost_factor=initial_agent_stats.agent_effectiveness.get(
             "support_ratio", None
         ),
-        duration__s=(agent_base_end_time - state["agent_start_time"]).total_seconds(),
+        duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
     )
 
     return InitialAnswerUpdate(
@@ -255,5 +251,7 @@ def generate_initial_answer(
         generated_sub_questions=decomp_questions,
         agent_base_end_time=agent_base_end_time,
         agent_base_metrics=agent_base_metrics,
-        log_messages=[f"Initial answer generation: {now_end - now_start}"],
+        log_messages=[
+            f"{now_start} -- Initial answer generation - Time taken: {now_end - now_start}"
+        ],
    )
@@ -25,7 +25,7 @@ def generate_initial_base_search_only_answer(
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     question = agent_a_config.search_request.query
-    original_question_docs = state["all_original_question_documents"]
+    original_question_docs = state.all_original_question_documents
 
     model = agent_a_config.fast_llm
 
@@ -61,12 +61,12 @@ def generate_refined_answer(
 
     history = build_history_prompt(agent_a_config.message_history)
 
-    initial_documents = state["documents"]
-    revised_documents = state["refined_documents"]
+    initial_documents = state.documents
+    revised_documents = state.refined_documents
 
     combined_documents = dedup_inference_sections(initial_documents, revised_documents)
 
-    query_info = get_query_info(state["original_question_retrieval_results"])
+    query_info = get_query_info(state.original_question_retrieval_results)
     # stream refined answer docs
     for tool_response in yield_search_responses(
         query=question,
@@ -93,8 +93,8 @@ def generate_refined_answer(
     else:
         revision_doc_effectiveness = 10.0
 
-    decomp_answer_results = state["decomp_answer_results"]
-    # revised_answer_results = state["refined_decomp_answer_results"]
+    decomp_answer_results = state.decomp_answer_results
+    # revised_answer_results = state.refined_decomp_answer_results
 
     good_qa_list: list[str] = []
     decomp_questions = []
@@ -147,7 +147,7 @@ def generate_refined_answer(
 
     # original answer
 
-    initial_answer = state["initial_answer"]
+    initial_answer = state.initial_answer
 
     # Determine which persona-specification prompt to use
 
@@ -218,7 +218,7 @@ def generate_refined_answer(
     answer = cast(str, response)
 
     # refined_agent_stats = _calculate_refined_agent_stats(
-    #     state["decomp_answer_results"], state["original_question_retrieval_stats"]
+    #     state.decomp_answer_results, state.original_question_retrieval_stats
     # )
 
     initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
@@ -252,22 +252,22 @@ def generate_refined_answer(
 
     logger.debug("-" * 100)
 
-    if state["initial_agent_stats"]:
-        initial_doc_boost_factor = state["initial_agent_stats"].agent_effectiveness.get(
+    if state.initial_agent_stats:
+        initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
             "utilized_chunk_ratio", "--"
         )
-        initial_support_boost_factor = state[
-            "initial_agent_stats"
-        ].agent_effectiveness.get("support_ratio", "--")
-        num_initial_verified_docs = state["initial_agent_stats"].original_question.get(
+        initial_support_boost_factor = (
+            state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
+        )
+        num_initial_verified_docs = state.initial_agent_stats.original_question.get(
             "num_verified_documents", "--"
         )
-        initial_verified_docs_avg_score = state[
-            "initial_agent_stats"
-        ].original_question.get("verified_avg_score", "--")
-        initial_sub_questions_verified_docs = state[
-            "initial_agent_stats"
-        ].sub_questions.get("num_verified_documents", "--")
+        initial_verified_docs_avg_score = (
+            state.initial_agent_stats.original_question.get("verified_avg_score", "--")
+        )
+        initial_sub_questions_verified_docs = (
+            state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
+        )
 
         logger.debug("INITIAL AGENT STATS")
         logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
@@ -296,9 +296,9 @@ def generate_refined_answer(
     )
 
     agent_refined_end_time = datetime.now()
-    if state["agent_refined_start_time"]:
+    if state.agent_refined_start_time:
         agent_refined_duration = (
-            agent_refined_end_time - state["agent_refined_start_time"]
+            agent_refined_end_time - state.agent_refined_start_time
         ).total_seconds()
     else:
         agent_refined_duration = None
@@ -15,9 +15,9 @@ def ingest_initial_base_retrieval(
 
     logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")
 
-    sub_question_retrieval_stats = state[
-        "base_expanded_retrieval_result"
-    ].sub_question_retrieval_stats
+    sub_question_retrieval_stats = (
+        state.base_expanded_retrieval_result.sub_question_retrieval_stats
+    )
     if sub_question_retrieval_stats is None:
         sub_question_retrieval_stats = AgentChunkStats()
     else:
@@ -30,11 +30,7 @@ def ingest_initial_base_retrieval(
     )
 
     return ExpandedRetrievalUpdate(
-        original_question_retrieval_results=state[
-            "base_expanded_retrieval_result"
-        ].expanded_queries_results,
-        all_original_question_documents=state[
-            "base_expanded_retrieval_result"
-        ].all_documents,
+        original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
+        all_original_question_documents=state.base_expanded_retrieval_result.all_documents,
         original_question_retrieval_stats=sub_question_retrieval_stats,
    )
@@ -17,7 +17,7 @@ def ingest_initial_sub_question_answers(
 
     logger.debug(f"--------{now_start}--------INGEST ANSWERS---")
     documents = []
-    answer_results = state.get("answer_results", [])
+    answer_results = state.answer_results if hasattr(state, "answer_results") else []
     for answer_result in answer_results:
         documents.extend(answer_result.documents)
 
@@ -18,7 +18,7 @@ def ingest_refined_answers(
     logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")
 
     documents = []
-    answer_results = state.get("answer_results", [])
+    answer_results = state.answer_results if hasattr(state, "answer_results") else []
     for answer_result in answer_results:
         documents.extend(answer_result.documents)
 
@@ -53,7 +53,7 @@ def initial_sub_question_creation(
     history = build_history_prompt(agent_a_config.message_history)
 
     # Use the initial search results to inform the decomposition
-    sample_doc_str = state.get("sample_doc_str", "")
+    sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
 
     if not chat_session_id or not primary_message_id:
         raise ValueError(
@@ -32,8 +32,8 @@ def refined_answer_decision(
         f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---"
     )
 
-    if not agent_a_config.allow_refinement:
+    if agent_a_config.allow_refinement:
         return RequireRefinedAnswerUpdate(require_refined_answer=decision)
 
     else:
-        return RequireRefinedAnswerUpdate(require_refined_answer=not decision)
+        return RequireRefinedAnswerUpdate(require_refined_answer=False)
@@ -37,7 +37,7 @@ def refined_sub_question_creation(
             tool_name="agent_search_1",
             tool_args={
                 "query": agent_a_config.search_request.query,
-                "answer": state["initial_answer"],
+                "answer": state.initial_answer,
             },
         ),
     )
@@ -49,16 +49,16 @@ def refined_sub_question_creation(
     agent_refined_start_time = datetime.now()
 
     question = agent_a_config.search_request.query
-    base_answer = state["initial_answer"]
+    base_answer = state.initial_answer
     history = build_history_prompt(agent_a_config.message_history)
     # get the entity term extraction dict and properly format it
-    entity_retlation_term_extractions = state["entity_retlation_term_extractions"]
+    entity_retlation_term_extractions = state.entity_retlation_term_extractions
 
     entity_term_extraction_str = format_entity_term_extraction(
         entity_retlation_term_extractions
     )
 
-    initial_question_answers = state["decomp_answer_results"]
+    initial_question_answers = state.decomp_answer_results
 
     addressed_question_list = [
         x.question for x in initial_question_answers if "yes" in x.quality.lower()
@@ -0,0 +1,12 @@
+from datetime import datetime
+
+from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
+from onyx.agents.agent_search.deep_search_a.main.states import MainState
+
+
+def retrieval_consolidation(
+    state: MainState,
+) -> LoggerUpdate:
+    now_start = datetime.now()
+
+    return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])
@@ -3,6 +3,8 @@ from operator import add
 from typing import Annotated
 from typing import TypedDict
 
+from pydantic import BaseModel
+
 from onyx.agents.agent_search.core_state import CoreState
 from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
     ExpandedRetrievalResult,
@@ -28,33 +30,36 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
 )
 from onyx.context.search.models import InferenceSection
 
 
 ### States ###
 
 ## Update States
 
 
-class RefinedAgentStartStats(TypedDict):
-    agent_refined_start_time: datetime | None
+class LoggerUpdate(BaseModel):
+    log_messages: Annotated[list[str], add] = []
 
 
-class RefinedAgentEndStats(TypedDict):
-    agent_refined_end_time: datetime | None
-    agent_refined_metrics: AgentRefinedMetrics
+class RefinedAgentStartStats(BaseModel):
+    agent_refined_start_time: datetime | None = None
 
 
-class BaseDecompUpdateBase(TypedDict):
-    agent_start_time: datetime
-    initial_decomp_questions: list[str]
+class RefinedAgentEndStats(BaseModel):
+    agent_refined_end_time: datetime | None = None
+    agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
 
 
-class RoutingDecisionBase(TypedDict):
-    routing: str
-    sample_doc_str: str
+class BaseDecompUpdateBase(BaseModel):
+    agent_start_time: datetime = datetime.now()
+    initial_decomp_questions: list[str] = []
 
 
-class RoutingDecision(RoutingDecisionBase):
-    log_messages: list[str]
+class RoutingDecisionBase(BaseModel):
+    routing: str = ""
+    sample_doc_str: str = ""
+
+
+class RoutingDecision(RoutingDecisionBase, LoggerUpdate):
+    pass
 
 
 class BaseDecompUpdate(
@@ -63,66 +68,72 @@ class BaseDecompUpdate(
     pass
 
 
-class InitialAnswerBASEUpdate(TypedDict):
-    initial_base_answer: str
+class InitialAnswerBASEUpdate(BaseModel):
+    initial_base_answer: str = ""
 
 
-class InitialAnswerUpdateBase(TypedDict):
-    initial_answer: str
-    initial_agent_stats: InitialAgentResultStats | None
-    generated_sub_questions: list[str]
-    agent_base_end_time: datetime
-    agent_base_metrics: AgentBaseMetrics | None
+class InitialAnswerUpdateBase(BaseModel):
+    initial_answer: str = ""
+    initial_agent_stats: InitialAgentResultStats | None = None
+    generated_sub_questions: list[str] = []
+    agent_base_end_time: datetime | None = None
+    agent_base_metrics: AgentBaseMetrics | None = None
 
 
-class InitialAnswerUpdate(InitialAnswerUpdateBase):
-    log_messages: list[str]
+class InitialAnswerUpdate(InitialAnswerUpdateBase, LoggerUpdate):
+    pass
 
 
-class RefinedAnswerUpdateBase(TypedDict):
-    refined_answer: str
-    refined_agent_stats: RefinedAgentStats | None
-    refined_answer_quality: bool
+class RefinedAnswerUpdateBase(BaseModel):
+    refined_answer: str = ""
+    refined_agent_stats: RefinedAgentStats | None = None
+    refined_answer_quality: bool = False
 
 
 class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase):
     pass
 
 
-class InitialAnswerQualityUpdate(TypedDict):
-    initial_answer_quality: bool
+class InitialAnswerQualityUpdate(BaseModel):
+    initial_answer_quality: bool = False
 
 
-class RequireRefinedAnswerUpdate(TypedDict):
-    require_refined_answer: bool
+class RequireRefinedAnswerUpdate(BaseModel):
+    require_refined_answer: bool = True
 
 
-class DecompAnswersUpdate(TypedDict):
-    documents: Annotated[list[InferenceSection], dedup_inference_sections]
+class DecompAnswersUpdate(BaseModel):
+    documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
     decomp_answer_results: Annotated[
         list[QuestionAnswerResults], dedup_question_answer_results
-    ]
+    ] = []
 
 
-class FollowUpDecompAnswersUpdate(TypedDict):
-    refined_documents: Annotated[list[InferenceSection], dedup_inference_sections]
-    refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add]
+class FollowUpDecompAnswersUpdate(BaseModel):
+    refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
+    refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
 
 
-class ExpandedRetrievalUpdate(TypedDict):
+class ExpandedRetrievalUpdate(BaseModel):
     all_original_question_documents: Annotated[
         list[InferenceSection], dedup_inference_sections
     ]
-    original_question_retrieval_results: list[QueryResult]
-    original_question_retrieval_stats: AgentChunkStats
+    original_question_retrieval_results: list[QueryResult] = []
+    original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
 
 
-class EntityTermExtractionUpdate(TypedDict):
-    entity_retlation_term_extractions: EntityRelationshipTermExtraction
+class EntityTermExtractionUpdateBase(BaseModel):
+    entity_retlation_term_extractions: EntityRelationshipTermExtraction = (
+        EntityRelationshipTermExtraction()
+    )
 
 
-class FollowUpSubQuestionsUpdateBase(TypedDict):
-    refined_sub_questions: dict[int, FollowUpSubQuestion]
+class EntityTermExtractionUpdate(EntityTermExtractionUpdateBase, LoggerUpdate):
+    pass
+
+
+class FollowUpSubQuestionsUpdateBase(BaseModel):
+    refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
 
 
 class FollowUpSubQuestionsUpdate(
@@ -145,12 +156,13 @@ class MainInput(CoreState):
 class MainState(
     # This includes the core state
     MainInput,
+    LoggerUpdate,
     BaseDecompUpdateBase,
     InitialAnswerUpdateBase,
     InitialAnswerBASEUpdate,
     DecompAnswersUpdate,
     ExpandedRetrievalUpdate,
-    EntityTermExtractionUpdate,
+    EntityTermExtractionUpdateBase,
     InitialAnswerQualityUpdate,
     RequireRefinedAnswerUpdate,
     FollowUpSubQuestionsUpdateBase,
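The states file above also shows the composition pattern the commit leans on: the overall `MainState` multiply-inherits the per-node update models (now BaseModels), so each node returns only its own slice of state while the graph still carries every field. Roughly, with illustrative classes:

from pydantic import BaseModel


class LoggerUpdate(BaseModel):
    log_messages: list[str] = []


class InitialAnswerUpdateBase(BaseModel):
    initial_answer: str = ""


class MainState(InitialAnswerUpdateBase, LoggerUpdate):
    base_question: str = ""


state = MainState(base_question="q")
# every mixin's fields are present, each with its default
assert state.initial_answer == "" and state.log_messages == []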
@@ -22,7 +22,7 @@ class AgentSearchConfig:
     primary_llm: LLM
     fast_llm: LLM
     search_tool: SearchTool
-    use_agentic_search: bool = False
+    use_agentic_search: bool = True
 
     # For persisting agent search data
     chat_session_id: UUID | None = None
@@ -37,13 +37,13 @@ class AgentSearchConfig:
     db_session: Session | None = None
 
     # Whether to perform initial search to inform decomposition
-    perform_initial_search_path_decision: bool = False
+    perform_initial_search_path_decision: bool = True
 
     # Whether to perform initial search to inform decomposition
-    perform_initial_search_decomposition: bool = False
+    perform_initial_search_decomposition: bool = True
 
     # Whether to allow creation of refinement questions (and entity extraction, etc.)
-    allow_refinement: bool = False
+    allow_refinement: bool = True
 
     # Message history for the current chat session
     message_history: list[PreviousMessage] | None = None
@@ -147,6 +147,7 @@ def run_graph(
     # TODO: add these to the environment
     config.perform_initial_search_path_decision = True
     config.perform_initial_search_decomposition = True
+    config.allow_refinement = True
 
     for event in _manage_async_event_streaming(
         compiled_graph=compiled_graph, config=config, graph_input=input
@@ -176,9 +177,9 @@ def run_main_graph(
 ) -> AnswerStream:
     compiled_graph = load_compiled_graph(graph_name)
     if graph_name == "a":
-        input = MainInput_a()
+        input = MainInput_a(base_question=config.search_request.query, log_messages=[])
     else:
-        input = MainInput_a()
+        input = MainInput_a(base_question=config.search_request.query, log_messages=[])
 
     # Agent search is not a Tool per se, but this is helpful for the frontend
     yield ToolCallKickoff(
@@ -238,9 +239,13 @@ if __name__ == "__main__":
     config.perform_initial_search_path_decision = True
     config.perform_initial_search_decomposition = True
     if GRAPH_NAME == "a":
-        input = MainInput_a()
+        input = MainInput_a(
+            base_question=config.search_request.query, log_messages=[]
+        )
     else:
-        input = MainInput_a()
+        input = MainInput_a(
+            base_question=config.search_request.query, log_messages=[]
+        )
     # with open("output.txt", "w") as f:
     tool_responses: list = []
     for output in run_graph(compiled_graph, config, input):
@@ -40,12 +40,12 @@ class AgentChunkScores(BaseModel):
 
 
 class AgentChunkStats(BaseModel):
-    verified_count: int | None
-    verified_avg_scores: float | None
-    rejected_count: int | None
-    rejected_avg_scores: float | None
-    verified_doc_chunk_ids: list[str]
-    dismissed_doc_chunk_ids: list[str]
+    verified_count: int | None = None
+    verified_avg_scores: float | None = None
+    rejected_count: int | None = None
+    rejected_avg_scores: float | None = None
+    verified_doc_chunk_ids: list[str] = []
+    dismissed_doc_chunk_ids: list[str] = []
 
 
 class InitialAgentResultStats(BaseModel):
@@ -60,29 +60,29 @@ class RefinedAgentStats(BaseModel):
 
 
 class Term(BaseModel):
-    term_name: str
-    term_type: str
-    term_similar_to: list[str]
+    term_name: str = ""
+    term_type: str = ""
+    term_similar_to: list[str] = []
 
 
 ### Models ###
 
 
 class Entity(BaseModel):
-    entity_name: str
-    entity_type: str
+    entity_name: str = ""
+    entity_type: str = ""
 
 
 class Relationship(BaseModel):
-    relationship_name: str
-    relationship_type: str
-    relationship_entities: list[str]
+    relationship_name: str = ""
+    relationship_type: str = ""
+    relationship_entities: list[str] = []
 
 
 class EntityRelationshipTermExtraction(BaseModel):
-    entities: list[Entity]
-    relationships: list[Relationship]
-    terms: list[Term]
+    entities: list[Entity] = []
+    relationships: list[Relationship] = []
+    terms: list[Term] = []
 
 
 ### Models ###
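A closing note on the defaulted ERT models above, which the changed extraction flow leans on: with defaults everywhere, an empty `EntityRelationshipTermExtraction()` is a valid fallback when the LLM returns nothing usable, and partial JSON still validates. A hedged sketch (pydantic v2, simplified illustrative models and payload):

from pydantic import BaseModel


class Entity(BaseModel):
    entity_name: str = ""
    entity_type: str = ""


class EntityRelationshipTermExtraction(BaseModel):
    entities: list[Entity] = []
    relationships: list = []
    terms: list = []


empty = EntityRelationshipTermExtraction()  # valid fallback value
partial = EntityRelationshipTermExtraction.model_validate(
    {"entities": [{"entity_name": "Onyx"}]}
)
assert partial.entities[0].entity_type == ""  # missing fields defaulted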