addressed TODOs

This commit is contained in:
Evan Lohn
2025-01-30 18:41:13 -08:00
parent a340529de3
commit 385b344a43
6 changed files with 41 additions and 26 deletions

View File

@@ -45,6 +45,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info, dispatch_main_answer_stop_info,
) )
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import ExtendedToolResponse
@@ -91,12 +92,14 @@ def generate_initial_answer(
if agent_a_config.search_tool is None: if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=question, query=question,
reranked_sections=relevant_docs, reranked_sections=relevant_docs,
final_context_sections=relevant_docs, final_context_sections=relevant_docs,
search_query_info=query_info, search_query_info=query_info,
-        get_section_relevance=lambda: None,  # TODO: add relevance
+        get_section_relevance=lambda: relevance_list,
search_tool=agent_a_config.search_tool, search_tool=agent_a_config.search_tool,
): ):
write_custom_event( write_custom_event(

View File

@@ -44,6 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
) )
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import ExtendedToolResponse
@@ -95,12 +96,13 @@ def generate_refined_answer(
if agent_a_config.search_tool is None: if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
# stream refined answer docs # stream refined answer docs
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=question, query=question,
reranked_sections=relevant_docs, reranked_sections=relevant_docs,
final_context_sections=relevant_docs, final_context_sections=relevant_docs,
search_query_info=query_info, search_query_info=query_info,
-        get_section_relevance=lambda: None,  # TODO: add relevance
+        get_section_relevance=lambda: relevance_list,
search_tool=agent_a_config.search_tool, search_tool=agent_a_config.search_tool,
): ):
write_custom_event( write_custom_event(

View File

@@ -141,15 +141,14 @@ def calculate_initial_agent_stats(
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo: def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
# Use the query info from the base document retrieval # Use the query info from the base document retrieval
# TODO: see if this is the right way to do this # this is used for some fields that are the same across the searches done
query_infos = [ query_info = None
result.query_info for result in results if result.query_info is not None for result in results:
] if result.query_info is not None:
if len(query_infos) == 0: query_info = result.query_info
return SearchQueryInfo( break
predicted_search=None, return query_info or SearchQueryInfo(
final_filters=IndexFilters(access_control_list=None), predicted_search=None,
recency_bias_multiplier=1.0, final_filters=IndexFilters(access_control_list=None),
) recency_bias_multiplier=1.0,
raise ValueError("No query info found") )
return query_infos[0]

View File

@@ -3,6 +3,7 @@ from typing import cast
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import ( from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import (
ExpandedRetrievalResult, ExpandedRetrievalResult,
) )
@@ -18,6 +19,7 @@ from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.states imp
from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
@@ -29,13 +31,7 @@ def format_results(
writer: StreamWriter = lambda _: None, writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate: ) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.sub_question_id or "0_0") level, question_nr = parse_question_id(state.sub_question_id or "0_0")
-    query_infos = [
-        result.query_info
-        for result in state.expanded_retrieval_results
-        if result.query_info is not None
-    ]
-    if len(query_infos) == 0:
-        raise ValueError("No query info found")
+    query_info = get_query_info(state.expanded_retrieval_results)
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs # main question docs will be sent later after aggregation and deduping with sub-question docs
@@ -50,12 +46,14 @@ def format_results(
if agent_a_config.search_tool is None: if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=state.question, query=state.question,
-        reranked_sections=state.retrieved_documents,  # TODO: rename params. (sections pre-merging here.)
-        search_query_info=query_infos[0],  # TODO: handle differing query infos?
-        get_section_relevance=lambda: None,  # TODO: add relevance
+        reranked_sections=state.retrieved_documents,
+        search_query_info=query_info,
+        get_section_relevance=lambda: relevance_list,
         final_context_sections=reranked_documents,
search_tool=agent_a_config.search_tool, search_tool=agent_a_config.search_tool,
): ):
write_custom_event( write_custom_event(

View File

@@ -80,7 +80,6 @@ def rerank_documents(
else: else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
now_end = datetime.now() now_end = datetime.now()
logger.info( logger.info(
f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}" f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"

View File

@@ -37,6 +37,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -60,7 +61,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search.search_tool import SearchTool
BaseMessage_Content = str | list[str | dict[str, Any]] BaseMessage_Content = str | list[str | dict[str, Any]]
@@ -391,3 +391,17 @@ def write_custom_event(
name: str, event: AnswerPacket, stream_writer: StreamWriter name: str, event: AnswerPacket, stream_writer: StreamWriter
) -> None: ) -> None:
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event)) stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
def relevance_from_docs(
    relevant_docs: list[InferenceSection],
) -> list[SectionRelevancePiece]:
    """Build a SectionRelevancePiece for each given section, marking all as relevant.

    The identifying fields (content, document_id, chunk_id) are copied from each
    section's center chunk; ``relevant`` is unconditionally True since every
    section passed in is treated as relevant.
    """
    pieces: list[SectionRelevancePiece] = []
    for section in relevant_docs:
        chunk = section.center_chunk
        pieces.append(
            SectionRelevancePiece(
                relevant=True,
                content=chunk.content,
                document_id=chunk.document_id,
                chunk_id=chunk.chunk_id,
            )
        )
    return pieces