addressing PR comments

This commit is contained in:
Evan Lohn 2025-01-30 15:36:11 -08:00
parent 2b8cd63b34
commit bb6d55783e
7 changed files with 45 additions and 29 deletions

View File

@ -85,7 +85,6 @@ if __name__ == "__main__":
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput(logs="")
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:

View File

@ -16,8 +16,9 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
class BasicInput(BaseModel):
# TODO: subclass global log update state
logs: str = ""
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
_unused: bool = True
## Graph Output State

View File

@ -16,17 +16,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: handle citations here; below is what was previously passed in
# see basic_use_tool_response.py for where these variables come from
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
def process_llm_stream(
stream: Iterator[BaseMessage],
messages: Iterator[BaseMessage],
should_stream_answer: bool,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
@ -46,20 +38,20 @@ def process_llm_stream(
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for response in stream:
answer_piece = response.content
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# TODO: handle non-string content
logger.warning(f"Received non-string content: {type(answer_piece)}")
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(response, AIMessageChunk) and (
response.tool_call_chunks or response.tool_calls
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += response # type: ignore
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(response, []):
for response_part in answer_handler.handle_response_part(message, []):
dispatch_custom_event(
"basic_response",
response_part,

View File

@ -87,6 +87,7 @@ def generate_initial_answer(
# Use the query info from the base document retrieval
query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses(

View File

@ -10,6 +10,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@ -141,5 +142,10 @@ def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
result.query_info for result in results if result.query_info is not None
]
if len(query_infos) == 0:
return SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)
raise ValueError("No query info found")
return query_infos[0]

View File

@ -26,6 +26,8 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.configs.agent_configs import ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_runner import ToolCallKickoff
@ -58,7 +60,6 @@ def _parse_agent_event(
# 1. It's a list of the names of every place we dispatch a custom event
# 2. We maintain the intended types yielded by each event
if event_type == "on_custom_event":
# TODO: different AnswerStream types for different events
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
@ -133,15 +134,30 @@ def _manage_async_event_streaming(
return _yield_async_to_sync()
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig,
graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
message_id = config.message_id if config else None
for event in compiled_graph.stream(
stream_mode="custom",
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
):
print(event)
return []
def run_graph(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig,
input: BasicInput | MainInput_a,
) -> AnswerStream:
# TODO: add these to the environment
# config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
config.allow_refinement = True
config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
config.allow_refinement = ALLOW_REFINEMENT
for event in _manage_async_event_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
@ -152,8 +168,8 @@ def run_graph(
yield parsed_object
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
# It doesn't actually take very long to load the graph, but we'd rather
# not compile it again on every request.
def load_compiled_graph() -> CompiledStateGraph:
global _COMPILED_GRAPH
if _COMPILED_GRAPH is None:
@ -176,13 +192,11 @@ def run_main_graph(
yield from run_graph(compiled_graph, config, input)
# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
config: AgentSearchConfig,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput()
return run_graph(compiled_graph, config, input)

View File

@ -1,5 +1,8 @@
import os
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
ALLOW_REFINEMENT = True
AGENT_DEFAULT_RETRIEVAL_HITS = 15
AGENT_DEFAULT_RERANKING_HITS = 10
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8