addressed TODOs

This commit is contained in:
Evan Lohn
2025-01-30 18:41:13 -08:00
parent a340529de3
commit 385b344a43
6 changed files with 41 additions and 26 deletions

View File

@@ -45,6 +45,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
@@ -91,12 +92,14 @@ def generate_initial_answer(
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: None, # TODO: add relevance
get_section_relevance=lambda: relevance_list,
search_tool=agent_a_config.search_tool,
):
write_custom_event(

View File

@@ -44,6 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
@@ -95,12 +96,13 @@ def generate_refined_answer(
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# stream refined answer docs
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: None, # TODO: add relevance
get_section_relevance=lambda: relevance_list,
search_tool=agent_a_config.search_tool,
):
write_custom_event(

View File

@@ -141,15 +141,14 @@ def calculate_initial_agent_stats(
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
    """Pick the query info to report for a group of retrieval results.

    The query info is used for some fields that are the same across the
    searches done, so the first result that carries one is taken as
    representative.

    Args:
        results: Retrieval results whose ``query_info`` may be ``None``.

    Returns:
        The ``query_info`` of the first result where it is not ``None``.
        If no result has one (including when ``results`` is empty), a
        default ``SearchQueryInfo`` with no predicted search, filters with
        ``access_control_list=None`` and a recency bias multiplier of 1.0.
    """
    # Use the query info from the base document retrieval: the first one
    # found wins, later ones are not inspected.
    for result in results:
        if result.query_info is not None:
            return result.query_info

    # Nothing carried query info; fall back to a default instead of raising
    # so callers can still stream their documents.
    return SearchQueryInfo(
        predicted_search=None,
        final_filters=IndexFilters(access_control_list=None),
        recency_bias_multiplier=1.0,
    )

View File

@@ -3,6 +3,7 @@ from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
@@ -18,6 +19,7 @@ from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.states imp
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
@@ -29,13 +31,7 @@ def format_results(
writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_infos = [
result.query_info
for result in state.expanded_retrieval_results
if result.query_info is not None
]
if len(query_infos) == 0:
raise ValueError("No query info found")
query_info = get_query_info(state.expanded_retrieval_results)
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs
@@ -50,12 +46,14 @@ def format_results(
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
search_query_info=query_infos[0], # TODO: handle differing query infos?
get_section_relevance=lambda: None, # TODO: add relevance
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=agent_a_config.search_tool,
):
write_custom_event(

View File

@@ -80,7 +80,6 @@ def rerank_documents(
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
now_end = datetime.now()
logger.info(
f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"

View File

@@ -37,6 +37,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -60,7 +61,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
BaseMessage_Content = str | list[str | dict[str, Any]]
@@ -391,3 +391,17 @@ def write_custom_event(
name: str, event: AnswerPacket, stream_writer: StreamWriter
) -> None:
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
def relevance_from_docs(
    relevant_docs: list[InferenceSection],
) -> list[SectionRelevancePiece]:
    """Build one relevance piece per section, marking every section relevant.

    Args:
        relevant_docs: Sections to convert; only each section's
            ``center_chunk`` is read.

    Returns:
        A list, in the same order as ``relevant_docs``, of
        ``SectionRelevancePiece`` objects with ``relevant=True`` and the
        content, document id and chunk id taken from the center chunk.
    """
    # Resolve each section's center chunk once, lazily, in input order.
    center_chunks = (section.center_chunk for section in relevant_docs)
    return [
        SectionRelevancePiece(
            relevant=True,
            content=chunk.content,
            document_id=chunk.document_id,
            chunk_id=chunk.chunk_id,
        )
        for chunk in center_chunks
    ]