fixed basic flow citations and second test

Evan Lohn 2025-01-23 10:18:22 -08:00
parent 3ced9bc28b
commit ca1f176c61
7 changed files with 141 additions and 50 deletions

View File

@@ -7,11 +7,11 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
ORIGINAL_CONTEXT_DOCUMENTS_ID,
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -34,11 +34,12 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
using_tool_calling_llm=agent_config.using_tool_calling_llm,
)
final_search_results = []
initial_search_results = []
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
cast(list[LlmDoc], yield_item.response)
elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = yield_item.response.contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
@@ -52,6 +53,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
process_llm_stream(stream, True)
process_llm_stream(
stream,
True,
final_search_results=final_search_results,
displayed_search_results=initial_search_results,
)
return BasicOutput()
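
In outline, the node now tracks two document lists: the documents ultimately handed to the LLM (final_search_results) and the documents originally surfaced by the search tool (initial_search_results), and forwards both to process_llm_stream so citations can be numbered against what the user actually saw. A minimal standalone sketch of that collection loop; names mirror the diff, but the dedup-by-document_id detail is an assumption rather than the committed code:

    # Sketch only; the dedup key is an assumption.
    final_search_results = []    # list[LlmDoc] under FINAL_CONTEXT_DOCUMENTS_ID
    initial_search_results = []  # docs recovered from SEARCH_DOC_CONTENT_ID
    seen_doc_ids = set()
    for yield_item in tool_call_responses:
        if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
            # The docs the LLM will cite, in their final order.
            final_search_results = cast(list[LlmDoc], yield_item.response)
        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
            # The docs as originally retrieved/displayed.
            for doc in yield_item.response.contexts:
                if doc.document_id not in seen_doc_ids:
                    seen_doc_ids.add(doc.document_id)
                    initial_search_results.append(doc)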

View File

@@ -42,11 +42,13 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
print("tool_kickoff", tool_kickoff)
# TODO: custom events for yields
emit_packet(tool_kickoff)
tool_responses = []
for response in tool_runner.tool_responses():
print("response", response.id)
tool_responses.append(response)
emit_packet(response)
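
The emit_packet helper is not defined in this diff; given the dispatch_custom_event usage in process_llm_stream below, one plausible shape (an assumption, not the committed code) is a thin wrapper over LangChain's custom-event dispatch:

    from langchain_core.callbacks import dispatch_custom_event

    def emit_packet(packet) -> None:
        # Push a tool kickoff/response packet onto the active LangGraph
        # event stream; the event name here is assumed.
        dispatch_custom_event("basic_response", packet)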

View File

@@ -6,7 +6,12 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -29,6 +34,18 @@ def process_llm_stream(
tool_call_chunk = AIMessageChunk(content="")
# for response in response_handler_manager.handle_llm_response(stream):
print("final_search_results", final_search_results)
print("displayed_search_results", displayed_search_results)
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = DummyAnswerResponseHandler()
print("entering stream")
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for response in stream:
@@ -44,9 +61,11 @@ def process_llm_stream(
tool_call_chunk += response # type: ignore
elif should_stream_answer:
# TODO: handle emitting of CitationInfo
dispatch_custom_event(
"basic_response",
OnyxAnswerPiece(answer_piece=answer_piece),
)
for response_part in answer_handler.handle_response_part(response, []):
print("resp part", response_part)
dispatch_custom_event(
"basic_response",
response_part,
)
return cast(AIMessageChunk, tool_call_chunk)
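
The handler selection above is the heart of the citation fix: the LLM numbers citations by a document's position in the final context it was given, while the UI numbers documents in the order they were originally displayed, so CitationResponseHandler needs both rank maps to translate between the two. A rough sketch of the kind of mapping map_document_id_order produces (assumed shape, not Onyx's implementation):

    # Assumed: document_id -> 1-based rank, in list order.
    def doc_id_to_rank(docs) -> dict[str, int]:
        return {doc.document_id: rank for rank, doc in enumerate(docs, start=1)}

    # A citation like [1] emitted against final_search_results can then be
    # re-numbered via the displayed map:
    #   displayed_rank = display_map[final_search_results[0].document_id]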

View File

@@ -49,9 +49,6 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
ORIGINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -395,7 +392,7 @@ class SearchTool(Tool):
final_search_results = cast(list[LlmDoc], yield_item.response)
elif (
isinstance(yield_item, ToolResponse)
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
and yield_item.id == SEARCH_DOC_CONTENT_ID
):
search_contexts = yield_item.response.contexts
# original_doc_search_rank = 1

View File

@@ -15,7 +15,6 @@ from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
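
Note the removed constant's value: "search_doc_content" is presumably the same string the new SEARCH_DOC_CONTENT_ID carries in search_tool.py, i.e. the rename relocates the ID without changing the packet's wire format (assumed, since that definition is not part of this diff):

    # Assumed definition in onyx/tools/tool_implementations/search/search_tool.py
    SEARCH_DOC_CONTENT_ID = "search_doc_content"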

View File

@@ -7,17 +7,23 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
@@ -82,37 +88,83 @@ def mock_llm() -> MagicMock:
@pytest.fixture
def mock_search_results() -> list[LlmDoc]:
def mock_inference_sections() -> list[InferenceSection]:
return [
LlmDoc(
content="Search result 1",
source_type=DocumentSource.WEB,
metadata={"id": "doc1"},
document_id="doc1",
blurb="Blurb 1",
semantic_identifier="Semantic ID 1",
updated_at=datetime(2023, 1, 1),
link="https://example.com/doc1",
source_links={0: "https://example.com/doc1"},
match_highlights=[],
InferenceSection(
combined_content="Search result 1",
center_chunk=InferenceChunk(
chunk_id=1,
section_continuation=False,
title=None,
boost=1,
recency_bias=0.5,
score=1.0,
hidden=False,
content="Search result 1",
source_type=DocumentSource.WEB,
metadata={"id": "doc1"},
document_id="doc1",
blurb="Blurb 1",
semantic_identifier="Semantic ID 1",
updated_at=datetime(2023, 1, 1),
source_links={0: "https://example.com/doc1"},
match_highlights=[],
),
chunks=MagicMock(),
),
LlmDoc(
content="Search result 2",
source_type=DocumentSource.WEB,
metadata={"id": "doc2"},
document_id="doc2",
blurb="Blurb 2",
semantic_identifier="Semantic ID 2",
updated_at=datetime(2023, 1, 2),
link="https://example.com/doc2",
source_links={0: "https://example.com/doc2"},
match_highlights=[],
InferenceSection(
combined_content="Search result 2",
center_chunk=InferenceChunk(
chunk_id=2,
section_continuation=False,
title=None,
boost=1,
recency_bias=0.5,
score=1.0,
hidden=False,
content="Search result 2",
source_type=DocumentSource.WEB,
metadata={"id": "doc2"},
document_id="doc2",
blurb="Blurb 2",
semantic_identifier="Semantic ID 2",
updated_at=datetime(2023, 1, 2),
source_links={0: "https://example.com/doc2"},
match_highlights=[],
),
chunks=MagicMock(),
),
]
@pytest.fixture
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
def mock_search_results(
mock_inference_sections: list[InferenceSection],
) -> list[LlmDoc]:
return [
llm_doc_from_inference_section(section) for section in mock_inference_sections
]
@pytest.fixture
def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
return OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in mock_inference_sections
]
)
@pytest.fixture
def mock_search_tool(
mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
) -> MagicMock:
mock_tool = MagicMock(spec=SearchTool)
mock_tool.name = "search"
mock_tool.build_tool_message_content.return_value = "search_response"
@@ -121,7 +173,8 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
json.loads(doc.model_dump_json()) for doc in mock_search_results
]
mock_tool.run.return_value = [
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)
ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
]
mock_tool.tool_definition.return_value = {
"type": "function",

View File

@@ -17,6 +17,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -25,6 +26,10 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.onyx.chat.conftest import QUERY
@@ -215,12 +220,13 @@ def test_answer_with_search_call(
def test_answer_with_search_no_tool_calling(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
) -> None:
answer_instance.tools = [mock_search_tool]
answer_instance.agent_search_config.tools = [mock_search_tool]
# Set up the LLM mock to return an answer
mock_llm = cast(Mock, answer_instance.llm)
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="Based on the search results, "),
AIMessageChunk(content="the answer is abc[1]. "),
@@ -228,10 +234,15 @@ def test_answer_with_search_no_tool_calling(
]
# Force non-tool calling behavior
answer_instance.using_tool_calling_llm = False
answer_instance.agent_search_config.using_tool_calling_llm = False
# Process the output
output = list(answer_instance.processed_streamed_output)
print("-" * 50)
for v in output:
print(v)
print()
print("-" * 50)
# Assertions
assert len(output) == 7
@@ -239,21 +250,25 @@ def test_answer_with_search_no_tool_calling(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
assert output[1] == ToolResponse(
id="final_context_documents",
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id=FINAL_CONTEXT_DOCUMENTS_ID,
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
assert output[3] == ToolCallFinalResult(
tool_name="search",
tool_args=DEFAULT_SEARCH_ARGS,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "