mirror of https://github.com/danswer-ai/danswer.git
addressed TODOs

@@ -45,6 +45,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
+from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
@@ -91,12 +92,14 @@ def generate_initial_answer(
 
     if agent_a_config.search_tool is None:
         raise ValueError("search_tool must be provided for agentic search")
+
+    relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
         reranked_sections=relevant_docs,
         final_context_sections=relevant_docs,
         search_query_info=query_info,
-        get_section_relevance=lambda: None,  # TODO: add relevance
+        get_section_relevance=lambda: relevance_list,
         search_tool=agent_a_config.search_tool,
     ):
         write_custom_event(
@@ -44,6 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
@@ -95,12 +96,13 @@ def generate_refined_answer(
     if agent_a_config.search_tool is None:
         raise ValueError("search_tool must be provided for agentic search")
     # stream refined answer docs
+    relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
         reranked_sections=relevant_docs,
         final_context_sections=relevant_docs,
         search_query_info=query_info,
-        get_section_relevance=lambda: None,  # TODO: add relevance
+        get_section_relevance=lambda: relevance_list,
         search_tool=agent_a_config.search_tool,
     ):
         write_custom_event(
@@ -141,15 +141,14 @@ def calculate_initial_agent_stats(
 
 def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
     # Use the query info from the base document retrieval
-    # TODO: see if this is the right way to do this
-    query_infos = [
-        result.query_info for result in results if result.query_info is not None
-    ]
-    if len(query_infos) == 0:
-        return SearchQueryInfo(
-            predicted_search=None,
-            final_filters=IndexFilters(access_control_list=None),
-            recency_bias_multiplier=1.0,
-        )
-        raise ValueError("No query info found")
-    return query_infos[0]
+    # this is used for some fields that are the same across the searches done
+    query_info = None
+    for result in results:
+        if result.query_info is not None:
+            query_info = result.query_info
+            break
+    return query_info or SearchQueryInfo(
+        predicted_search=None,
+        final_filters=IndexFilters(access_control_list=None),
+        recency_bias_multiplier=1.0,
+    )
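
The hunk above swaps the list-comprehension-plus-raise for a first-non-None scan with a neutral fallback, and the format_results hunks below reuse it through the shared get_query_info import. A minimal standalone sketch of that pattern follows; QueryResult, SearchQueryInfo, and IndexFilters here are simplified stand-ins for illustration, not the onyx models.

# Standalone sketch of the "first non-None query info, else neutral defaults" pattern.
# QueryResult, SearchQueryInfo, and IndexFilters are simplified stand-ins, not the onyx models.
from dataclasses import dataclass


@dataclass
class IndexFilters:
    access_control_list: list[str] | None


@dataclass
class SearchQueryInfo:
    predicted_search: str | None
    final_filters: IndexFilters
    recency_bias_multiplier: float


@dataclass
class QueryResult:
    query_info: SearchQueryInfo | None = None


def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
    # Take the first result that carries query info; if none does,
    # fall back to neutral defaults instead of raising.
    query_info = None
    for result in results:
        if result.query_info is not None:
            query_info = result.query_info
            break
    return query_info or SearchQueryInfo(
        predicted_search=None,
        final_filters=IndexFilters(access_control_list=None),
        recency_bias_multiplier=1.0,
    )


# With no populated query_info, the fallback is returned rather than an exception.
print(get_query_info([QueryResult(), QueryResult()]))
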
@@ -3,6 +3,7 @@ from typing import cast
 from langchain_core.runnables.config import RunnableConfig
 from langgraph.types import StreamWriter
 
+from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
 from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import (
     ExpandedRetrievalResult,
 )
@@ -18,6 +19,7 @@ from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.states imp
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import ExtendedToolResponse
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
@@ -29,13 +31,7 @@ def format_results(
     writer: StreamWriter = lambda _: None,
 ) -> ExpandedRetrievalUpdate:
     level, question_nr = parse_question_id(state.sub_question_id or "0_0")
-    query_infos = [
-        result.query_info
-        for result in state.expanded_retrieval_results
-        if result.query_info is not None
-    ]
-    if len(query_infos) == 0:
-        raise ValueError("No query info found")
+    query_info = get_query_info(state.expanded_retrieval_results)
 
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     # main question docs will be sent later after aggregation and deduping with sub-question docs
@@ -50,12 +46,14 @@ def format_results(
 
     if agent_a_config.search_tool is None:
         raise ValueError("search_tool must be provided for agentic search")
+
+    relevance_list = relevance_from_docs(reranked_documents)
     for tool_response in yield_search_responses(
         query=state.question,
-        reranked_sections=state.retrieved_documents,  # TODO: rename params. (sections pre-merging here.)
+        reranked_sections=state.retrieved_documents,
         final_context_sections=reranked_documents,
-        search_query_info=query_infos[0],  # TODO: handle differing query infos?
-        get_section_relevance=lambda: None,  # TODO: add relevance
+        search_query_info=query_info,
+        get_section_relevance=lambda: relevance_list,
         search_tool=agent_a_config.search_tool,
     ):
         write_custom_event(
@@ -80,7 +80,6 @@ def rerank_documents(
     else:
         fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
 
-    # TODO: stream deduped docs here, or decide to use search tool ranking/verification
     now_end = datetime.now()
     logger.info(
         f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
@@ -37,6 +37,7 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import DocumentPruningConfig
 from onyx.chat.models import PromptConfig
+from onyx.chat.models import SectionRelevancePiece
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -60,7 +61,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 
-
 BaseMessage_Content = str | list[str | dict[str, Any]]
 
 
@@ -391,3 +391,17 @@ def write_custom_event(
     name: str, event: AnswerPacket, stream_writer: StreamWriter
 ) -> None:
     stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
+
+
+def relevance_from_docs(
+    relevant_docs: list[InferenceSection],
+) -> list[SectionRelevancePiece]:
+    return [
+        SectionRelevancePiece(
+            relevant=True,
+            content=doc.center_chunk.content,
+            document_id=doc.center_chunk.document_id,
+            chunk_id=doc.center_chunk.chunk_id,
+        )
+        for doc in relevant_docs
+    ]
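
Taken together, the hunks above converge on one call-site pattern: relevance is derived once from the retained sections via the new relevance_from_docs helper and handed to yield_search_responses through a zero-argument callable, replacing the earlier get_section_relevance=lambda: None placeholders. Below is a minimal standalone sketch of that pattern; the chunk/section classes and the yield_search_responses signature are simplified stand-ins, not the onyx implementations.

# Standalone sketch; every class and the generator below are simplified stand-ins.
from dataclasses import dataclass
from typing import Callable, Iterator, Optional


@dataclass
class Chunk:
    document_id: str
    chunk_id: int
    content: str


@dataclass
class Section:
    center_chunk: Chunk


@dataclass
class SectionRelevancePiece:
    relevant: bool
    content: str
    document_id: str
    chunk_id: int


def relevance_from_docs(relevant_docs: list[Section]) -> list[SectionRelevancePiece]:
    # Every section that survived retrieval and reranking is reported as relevant.
    return [
        SectionRelevancePiece(
            relevant=True,
            content=doc.center_chunk.content,
            document_id=doc.center_chunk.document_id,
            chunk_id=doc.center_chunk.chunk_id,
        )
        for doc in relevant_docs
    ]


def yield_search_responses(
    final_context_sections: list[Section],
    get_section_relevance: Callable[[], Optional[list[SectionRelevancePiece]]],
) -> Iterator[object]:
    # Stand-in for the search tool's response generator: it pulls relevance
    # lazily through the zero-argument callable supplied by the caller.
    yield final_context_sections
    yield get_section_relevance()


docs = [Section(Chunk("doc-1", 0, "alpha")), Section(Chunk("doc-2", 3, "beta"))]
relevance_list = relevance_from_docs(docs)  # computed once, up front
for response in yield_search_responses(docs, get_section_relevance=lambda: relevance_list):
    print(response)
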