addressing PR comments

Evan Lohn 2025-01-30 15:36:11 -08:00
parent 2b8cd63b34
commit bb6d55783e
7 changed files with 45 additions and 29 deletions

View File

@@ -85,7 +85,6 @@ if __name__ == "__main__":
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
-    # TODO: unify basic input
     input = BasicInput(logs="")
     primary_llm, fast_llm = get_default_llms()
     with get_session_context_manager() as db_session:
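One detail worth noting about the unchanged input = BasicInput(logs="") line: the next file in this commit removes the logs field, yet this call keeps working because pydantic v2 models ignore unknown keyword fields by default. A minimal, repo-independent sketch:

from pydantic import BaseModel

class BasicInput(BaseModel):
    # mirrors the new definition in the next file
    _unused: bool = True

inp = BasicInput(logs="")  # "logs" is silently dropped under the default extra="ignore"
print(inp)                 # -> BasicInput()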

View File

@@ -16,8 +16,9 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 class BasicInput(BaseModel):
-    # TODO: subclass global log update state
-    logs: str = ""
+    # Langgraph needs a nonempty input, but we pass in all static
+    # data through a RunnableConfig.
+    _unused: bool = True
 
 
 ## Graph Output State
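The replacement comment describes the pattern this PR standardizes on: the graph input stays a placeholder, and per-request data rides in the RunnableConfig that langgraph threads into every node. A minimal sketch of the read side (the node and state shape are illustrative, not the repo's):

from langchain_core.runnables import RunnableConfig

def example_node(state: dict, config: RunnableConfig) -> dict:
    # run_graph (later in this commit) stores the AgentSearchConfig under
    # metadata["config"], so nodes can read static data from there
    agent_config = config["metadata"]["config"]
    return {"received": agent_config is not None}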

View File

@@ -16,17 +16,9 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()
 
-# TODO: handle citations here; below is what was previously passed in
-# see basic_use_tool_response.py for where these variables come from
-# answer_handler = CitationResponseHandler(
-#     context_docs=final_search_results,
-#     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
-#     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
-# )
-
 def process_llm_stream(
-    stream: Iterator[BaseMessage],
+    messages: Iterator[BaseMessage],
     should_stream_answer: bool,
     final_search_results: list[LlmDoc] | None = None,
     displayed_search_results: list[LlmDoc] | None = None,

@@ -46,20 +38,20 @@ def process_llm_stream(
     full_answer = ""
 
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
-    for response in stream:
-        answer_piece = response.content
+    for message in messages:
+        answer_piece = message.content
         if not isinstance(answer_piece, str):
-            # TODO: handle non-string content
-            logger.warning(f"Received non-string content: {type(answer_piece)}")
+            # this is only used for logging, so fine to
+            # just add the string representation
             answer_piece = str(answer_piece)
         full_answer += answer_piece
 
-        if isinstance(response, AIMessageChunk) and (
-            response.tool_call_chunks or response.tool_calls
+        if isinstance(message, AIMessageChunk) and (
+            message.tool_call_chunks or message.tool_calls
         ):
-            tool_call_chunk += response  # type: ignore
+            tool_call_chunk += message  # type: ignore
         elif should_stream_answer:
-            for response_part in answer_handler.handle_response_part(response, []):
+            for response_part in answer_handler.handle_response_part(message, []):
                 dispatch_custom_event(
                     "basic_response",
                     response_part,
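The accumulation with tool_call_chunk += message relies on AIMessageChunk implementing +, which splices partial tool-call deltas together by index. A standalone sketch of that merging behavior, assuming a recent langchain_core:

from langchain_core.messages import AIMessageChunk

first = AIMessageChunk(
    content="",
    tool_call_chunks=[
        {"name": "search", "args": '{"query": "on', "id": "call_1", "index": 0}
    ],
)
second = AIMessageChunk(
    content="",
    tool_call_chunks=[{"name": None, "args": 'yx"}', "id": None, "index": 0}],
)
merged = first + second  # chunks sharing an index are spliced together
print(merged.tool_calls)  # -> [{'name': 'search', 'args': {'query': 'onyx'}, ...}]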

View File

@@ -87,6 +87,7 @@ def generate_initial_answer(
     # Use the query info from the base document retrieval
     query_info = get_query_info(state.original_question_retrieval_results)
+
     if agent_a_config.search_tool is None:
         raise ValueError("search_tool must be provided for agentic search")
     for tool_response in yield_search_responses(

View File

@@ -10,6 +10,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
     QuestionAnswerResults,
 )
 from onyx.chat.models import SubQuestionPiece
+from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
 from onyx.utils.logger import setup_logger

@@ -141,5 +142,10 @@ def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
         result.query_info for result in results if result.query_info is not None
     ]
     if len(query_infos) == 0:
+        return SearchQueryInfo(
+            predicted_search=None,
+            final_filters=IndexFilters(access_control_list=None),
+            recency_bias_multiplier=1.0,
+        )
         raise ValueError("No query info found")
     return query_infos[0]
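Worth tracing the control flow after this insertion: once the new return runs, the pre-existing raise directly below it can never execute. For reference, the permissive default the function now hands back on an empty result list, built standalone with the imports added above:

from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo

# what get_query_info now returns instead of raising
fallback = SearchQueryInfo(
    predicted_search=None,
    final_filters=IndexFilters(access_control_list=None),
    recency_bias_multiplier=1.0,
)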

View File

@@ -26,6 +26,8 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import SubQueryPiece
 from onyx.chat.models import SubQuestionPiece
 from onyx.chat.models import ToolResponse
+from onyx.configs.agent_configs import ALLOW_REFINEMENT
+from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
 from onyx.context.search.models import SearchRequest
 from onyx.db.engine import get_session_context_manager
 from onyx.tools.tool_runner import ToolCallKickoff

@@ -58,7 +60,6 @@ def _parse_agent_event(
     # 1. It's a list of the names of every place we dispatch a custom event
     # 2. We maintain the intended types yielded by each event
     if event_type == "on_custom_event":
-        # TODO: different AnswerStream types for different events
         if event["name"] == "decomp_qs":
             return cast(SubQuestionPiece, event["data"])
         elif event["name"] == "subqueries":
@@ -133,15 +134,30 @@ def _manage_async_event_streaming(
     return _yield_async_to_sync()
 
 
+def manage_sync_streaming(
+    compiled_graph: CompiledStateGraph,
+    config: AgentSearchConfig,
+    graph_input: BasicInput | MainInput_a,
+) -> Iterable[StreamEvent]:
+    message_id = config.message_id if config else None
+    for event in compiled_graph.stream(
+        stream_mode="custom",
+        input=graph_input,
+        config={"metadata": {"config": config, "thread_id": str(message_id)}},
+        # debug=True,
+    ):
+        print(event)
+    return []
+
+
 def run_graph(
     compiled_graph: CompiledStateGraph,
     config: AgentSearchConfig,
     input: BasicInput | MainInput_a,
 ) -> AnswerStream:
-    # TODO: add these to the environment
-    # config.perform_initial_search_path_decision = False
-    config.perform_initial_search_decomposition = True
-    config.allow_refinement = True
+    config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
+    config.allow_refinement = ALLOW_REFINEMENT
 
     for event in _manage_async_event_streaming(
         compiled_graph=compiled_graph, config=config, graph_input=input
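For comparison with the async path, here is the call shape manage_sync_streaming uses, as a runnable, repo-independent sketch. The toy graph is illustrative; it streams with stream_mode="updates" so it prints something, whereas the helper's "custom" mode surfaces only events a node emits explicitly:

from typing import TypedDict

from langgraph.graph import END, START, StateGraph

class State(TypedDict):
    n: int

def bump(state: State) -> State:
    return {"n": state["n"] + 1}

builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_edge(START, "bump")
builder.add_edge("bump", END)
compiled_graph = builder.compile()

# same call shape as manage_sync_streaming above
for event in compiled_graph.stream(
    input={"n": 0},
    config={"metadata": {"thread_id": "debug"}},
    stream_mode="updates",
):
    print(event)  # -> {'bump': {'n': 1}}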
@@ -152,8 +168,8 @@ def run_graph(
         yield parsed_object
 
 
-# TODO: call this once on startup, TBD where and if it should be gated based
-# on dev mode or not
+# It doesn't actually take very long to load the graph, but we'd rather
+# not compile it again on every request.
 def load_compiled_graph() -> CompiledStateGraph:
     global _COMPILED_GRAPH
     if _COMPILED_GRAPH is None:
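The rewritten comment describes a hand-rolled compile-once cache. An equivalent formulation with functools.lru_cache, shown only as a sketch (the builder here is a stand-in, not the repo's):

from functools import lru_cache

def build_and_compile_graph() -> object:
    # stand-in for main_graph_builder().compile()
    return object()

@lru_cache(maxsize=1)
def load_compiled_graph_cached() -> object:
    # the body runs once; every later call returns the cached compiled graph
    return build_and_compile_graph()

assert load_compiled_graph_cached() is load_compiled_graph_cached()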
@@ -176,13 +192,11 @@ def run_main_graph(
     yield from run_graph(compiled_graph, config, input)
 
 
-# TODO: unify input types, especially prosearchconfig
 def run_basic_graph(
     config: AgentSearchConfig,
 ) -> AnswerStream:
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
-    # TODO: unify basic input
     input = BasicInput()
     return run_graph(compiled_graph, config, input)

View File

@@ -1,5 +1,8 @@
 import os
 
+INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
+ALLOW_REFINEMENT = True
+
 AGENT_DEFAULT_RETRIEVAL_HITS = 15
 AGENT_DEFAULT_RERANKING_HITS = 10
 AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
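The TODO removed from run_graph.py wanted these toggles in the environment, but the new module pins both to True. If they were later made env-driven, it might look like this sketch (the parsing below is an assumption, not part of the commit):

import os

# hypothetical env-driven variants; the committed file hardcodes both to True
INITIAL_SEARCH_DECOMPOSITION_ENABLED = (
    os.environ.get("INITIAL_SEARCH_DECOMPOSITION_ENABLED", "true").lower() == "true"
)
ALLOW_REFINEMENT = os.environ.get("ALLOW_REFINEMENT", "true").lower() == "true"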