From ca1f176c6191dcf2d2b0eb411663c8e3d877c49d Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 23 Jan 2025 10:18:22 -0800 Subject: [PATCH] fixed basic flow citations and second test --- .../basic/nodes/basic_use_tool_response.py | 18 ++- .../agent_search/basic/nodes/tool_call.py | 2 + .../onyx/agents/agent_search/basic/utils.py | 29 ++++- .../search/search_tool.py | 5 +- .../search_like_tool_utils.py | 1 - backend/tests/unit/onyx/chat/conftest.py | 103 +++++++++++++----- backend/tests/unit/onyx/chat/test_answer.py | 33 ++++-- 7 files changed, 141 insertions(+), 50 deletions(-) diff --git a/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py b/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py index 81ca3cab7..c251db434 100644 --- a/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py +++ b/backend/onyx/agents/agent_search/basic/nodes/basic_use_tool_response.py @@ -7,11 +7,11 @@ from onyx.agents.agent_search.basic.states import BasicState from onyx.agents.agent_search.basic.utils import process_llm_stream from onyx.agents.agent_search.models import AgentSearchConfig from onyx.chat.models import LlmDoc -from onyx.tools.tool_implementations.search_like_tool_utils import ( - FINAL_CONTEXT_DOCUMENTS_ID, +from onyx.tools.tool_implementations.search.search_tool import ( + SEARCH_DOC_CONTENT_ID, ) from onyx.tools.tool_implementations.search_like_tool_utils import ( - ORIGINAL_CONTEXT_DOCUMENTS_ID, + FINAL_CONTEXT_DOCUMENTS_ID, ) @@ -34,11 +34,12 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO using_tool_calling_llm=agent_config.using_tool_calling_llm, ) + final_search_results = [] initial_search_results = [] for yield_item in tool_call_responses: if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID: - cast(list[LlmDoc], yield_item.response) - elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID: + final_search_results = cast(list[LlmDoc], yield_item.response) + elif yield_item.id == SEARCH_DOC_CONTENT_ID: search_contexts = yield_item.response.contexts for doc in search_contexts: if doc.document_id not in initial_search_results: @@ -52,6 +53,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO ) # For now, we don't do multiple tool calls, so we ignore the tool_message - process_llm_stream(stream, True) + process_llm_stream( + stream, + True, + final_search_results=final_search_results, + displayed_search_results=initial_search_results, + ) return BasicOutput() diff --git a/backend/onyx/agents/agent_search/basic/nodes/tool_call.py b/backend/onyx/agents/agent_search/basic/nodes/tool_call.py index 00f2c629b..7364fa25f 100644 --- a/backend/onyx/agents/agent_search/basic/nodes/tool_call.py +++ b/backend/onyx/agents/agent_search/basic/nodes/tool_call.py @@ -42,11 +42,13 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate: tool_runner = ToolRunner(tool, tool_args) tool_kickoff = tool_runner.kickoff() + print("tool_kickoff", tool_kickoff) # TODO: custom events for yields emit_packet(tool_kickoff) tool_responses = [] for response in tool_runner.tool_responses(): + print("response", response.id) tool_responses.append(response) emit_packet(response) diff --git a/backend/onyx/agents/agent_search/basic/utils.py b/backend/onyx/agents/agent_search/basic/utils.py index 7257770ca..508f1fc14 100644 --- a/backend/onyx/agents/agent_search/basic/utils.py +++ b/backend/onyx/agents/agent_search/basic/utils.py @@ -6,7 +6,12 @@ from langchain_core.messages import AIMessageChunk from langchain_core.messages import BaseMessage from onyx.chat.models import LlmDoc -from onyx.chat.models import OnyxAnswerPiece +from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler +from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler +from onyx.chat.stream_processing.answer_response_handler import ( + DummyAnswerResponseHandler, +) +from onyx.chat.stream_processing.utils import map_document_id_order from onyx.utils.logger import setup_logger logger = setup_logger() @@ -29,6 +34,18 @@ def process_llm_stream( tool_call_chunk = AIMessageChunk(content="") # for response in response_handler_manager.handle_llm_response(stream): + print("final_search_results", final_search_results) + print("displayed_search_results", displayed_search_results) + if final_search_results and displayed_search_results: + answer_handler: AnswerResponseHandler = CitationResponseHandler( + context_docs=final_search_results, + final_doc_id_to_rank_map=map_document_id_order(final_search_results), + display_doc_id_to_rank_map=map_document_id_order(displayed_search_results), + ) + else: + answer_handler = DummyAnswerResponseHandler() + + print("entering stream") # This stream will be the llm answer if no tool is chosen. When a tool is chosen, # the stream will contain AIMessageChunks with tool call information. for response in stream: @@ -44,9 +61,11 @@ def process_llm_stream( tool_call_chunk += response # type: ignore elif should_stream_answer: # TODO: handle emitting of CitationInfo - dispatch_custom_event( - "basic_response", - OnyxAnswerPiece(answer_piece=answer_piece), - ) + for response_part in answer_handler.handle_response_part(response, []): + print("resp part", response_part) + dispatch_custom_event( + "basic_response", + response_part, + ) return cast(AIMessageChunk, tool_call_chunk) diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index f139e114f..e4d698091 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -49,9 +49,6 @@ from onyx.tools.tool_implementations.search_like_tool_utils import ( from onyx.tools.tool_implementations.search_like_tool_utils import ( FINAL_CONTEXT_DOCUMENTS_ID, ) -from onyx.tools.tool_implementations.search_like_tool_utils import ( - ORIGINAL_CONTEXT_DOCUMENTS_ID, -) from onyx.utils.logger import setup_logger from onyx.utils.special_types import JSON_ro @@ -395,7 +392,7 @@ class SearchTool(Tool): final_search_results = cast(list[LlmDoc], yield_item.response) elif ( isinstance(yield_item, ToolResponse) - and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID + and yield_item.id == SEARCH_DOC_CONTENT_ID ): search_contexts = yield_item.response.contexts # original_doc_search_rank = 1 diff --git a/backend/onyx/tools/tool_implementations/search_like_tool_utils.py b/backend/onyx/tools/tool_implementations/search_like_tool_utils.py index cf4dfda08..2b307c2c2 100644 --- a/backend/onyx/tools/tool_implementations/search_like_tool_utils.py +++ b/backend/onyx/tools/tool_implementations/search_like_tool_utils.py @@ -15,7 +15,6 @@ from onyx.tools.message import ToolCallSummary from onyx.tools.models import ToolResponse -ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content" FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py index 684cfb620..e4170efb9 100644 --- a/backend/tests/unit/onyx/chat/conftest.py +++ b/backend/tests/unit/onyx/chat/conftest.py @@ -7,17 +7,23 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.chat.chat_utils import llm_doc_from_inference_section from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationConfig from onyx.chat.models import LlmDoc +from onyx.chat.models import OnyxContext +from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.configs.constants import DocumentSource +from onyx.context.search.models import InferenceChunk +from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMConfig from onyx.tools.force import ForceUseTool from onyx.tools.models import ToolResponse +from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search_like_tool_utils import ( FINAL_CONTEXT_DOCUMENTS_ID, @@ -82,37 +88,83 @@ def mock_llm() -> MagicMock: @pytest.fixture -def mock_search_results() -> list[LlmDoc]: +def mock_inference_sections() -> list[InferenceSection]: return [ - LlmDoc( - content="Search result 1", - source_type=DocumentSource.WEB, - metadata={"id": "doc1"}, - document_id="doc1", - blurb="Blurb 1", - semantic_identifier="Semantic ID 1", - updated_at=datetime(2023, 1, 1), - link="https://example.com/doc1", - source_links={0: "https://example.com/doc1"}, - match_highlights=[], + InferenceSection( + combined_content="Search result 1", + center_chunk=InferenceChunk( + chunk_id=1, + section_continuation=False, + title=None, + boost=1, + recency_bias=0.5, + score=1.0, + hidden=False, + content="Search result 1", + source_type=DocumentSource.WEB, + metadata={"id": "doc1"}, + document_id="doc1", + blurb="Blurb 1", + semantic_identifier="Semantic ID 1", + updated_at=datetime(2023, 1, 1), + source_links={0: "https://example.com/doc1"}, + match_highlights=[], + ), + chunks=MagicMock(), ), - LlmDoc( - content="Search result 2", - source_type=DocumentSource.WEB, - metadata={"id": "doc2"}, - document_id="doc2", - blurb="Blurb 2", - semantic_identifier="Semantic ID 2", - updated_at=datetime(2023, 1, 2), - link="https://example.com/doc2", - source_links={0: "https://example.com/doc2"}, - match_highlights=[], + InferenceSection( + combined_content="Search result 2", + center_chunk=InferenceChunk( + chunk_id=2, + section_continuation=False, + title=None, + boost=1, + recency_bias=0.5, + score=1.0, + hidden=False, + content="Search result 2", + source_type=DocumentSource.WEB, + metadata={"id": "doc2"}, + document_id="doc2", + blurb="Blurb 2", + semantic_identifier="Semantic ID 2", + updated_at=datetime(2023, 1, 2), + source_links={0: "https://example.com/doc2"}, + match_highlights=[], + ), + chunks=MagicMock(), ), ] @pytest.fixture -def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock: +def mock_search_results( + mock_inference_sections: list[InferenceSection], +) -> list[LlmDoc]: + return [ + llm_doc_from_inference_section(section) for section in mock_inference_sections + ] + + +@pytest.fixture +def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts: + return OnyxContexts( + contexts=[ + OnyxContext( + content=section.combined_content, + document_id=section.center_chunk.document_id, + semantic_identifier=section.center_chunk.semantic_identifier, + blurb=section.center_chunk.blurb, + ) + for section in mock_inference_sections + ] + ) + + +@pytest.fixture +def mock_search_tool( + mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc] +) -> MagicMock: mock_tool = MagicMock(spec=SearchTool) mock_tool.name = "search" mock_tool.build_tool_message_content.return_value = "search_response" @@ -121,7 +173,8 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock: json.loads(doc.model_dump_json()) for doc in mock_search_results ] mock_tool.run.return_value = [ - ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results) + ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts), + ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results), ] mock_tool.tool_definition.return_value = { "type": "function", diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index 50b27c17d..312cca1fc 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -17,6 +17,7 @@ from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationInfo from onyx.chat.models import LlmDoc from onyx.chat.models import OnyxAnswerPiece +from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopReason @@ -25,6 +26,10 @@ from onyx.tools.force import ForceUseTool from onyx.tools.models import ToolCallFinalResult from onyx.tools.models import ToolCallKickoff from onyx.tools.models import ToolResponse +from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID +from onyx.tools.tool_implementations.search_like_tool_utils import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS from tests.unit.onyx.chat.conftest import QUERY @@ -215,12 +220,13 @@ def test_answer_with_search_call( def test_answer_with_search_no_tool_calling( answer_instance: Answer, mock_search_results: list[LlmDoc], + mock_contexts: OnyxContexts, mock_search_tool: MagicMock, ) -> None: - answer_instance.tools = [mock_search_tool] + answer_instance.agent_search_config.tools = [mock_search_tool] # Set up the LLM mock to return an answer - mock_llm = cast(Mock, answer_instance.llm) + mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm) mock_llm.stream.return_value = [ AIMessageChunk(content="Based on the search results, "), AIMessageChunk(content="the answer is abc[1]. "), @@ -228,10 +234,15 @@ def test_answer_with_search_no_tool_calling( ] # Force non-tool calling behavior - answer_instance.using_tool_calling_llm = False + answer_instance.agent_search_config.using_tool_calling_llm = False # Process the output output = list(answer_instance.processed_streamed_output) + print("-" * 50) + for v in output: + print(v) + print() + print("-" * 50) # Assertions assert len(output) == 7 @@ -239,21 +250,25 @@ def test_answer_with_search_no_tool_calling( tool_name="search", tool_args=DEFAULT_SEARCH_ARGS ) assert output[1] == ToolResponse( - id="final_context_documents", + id=SEARCH_DOC_CONTENT_ID, + response=mock_contexts, + ) + assert output[2] == ToolResponse( + id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results, ) - assert output[2] == ToolCallFinalResult( + assert output[3] == ToolCallFinalResult( tool_name="search", tool_args=DEFAULT_SEARCH_ARGS, tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], ) - assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ") + assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ") expected_citation = CitationInfo(citation_num=1, document_id="doc1") - assert output[4] == expected_citation - assert output[5] == OnyxAnswerPiece( + assert output[5] == expected_citation + assert output[6] == OnyxAnswerPiece( answer_piece="the answer is abc[[1]](https://example.com/doc1). " ) - assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.") + assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.") expected_answer = ( "Based on the search results, "