Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-17 23:30:07 +02:00)
fixed basic flow citations and second test
This commit is contained in:
parent 3ced9bc28b
commit ca1f176c61
@@ -7,11 +7,11 @@ from onyx.agents.agent_search.basic.states import BasicState
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.models import LlmDoc
-from onyx.tools.tool_implementations.search_like_tool_utils import (
-    FINAL_CONTEXT_DOCUMENTS_ID,
+from onyx.tools.tool_implementations.search.search_tool import (
+    SEARCH_DOC_CONTENT_ID,
 )
 from onyx.tools.tool_implementations.search_like_tool_utils import (
-    ORIGINAL_CONTEXT_DOCUMENTS_ID,
+    FINAL_CONTEXT_DOCUMENTS_ID,
 )
 
 
@@ -34,11 +34,12 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
         using_tool_calling_llm=agent_config.using_tool_calling_llm,
     )
 
     final_search_results = []
+    initial_search_results = []
     for yield_item in tool_call_responses:
         if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
-            cast(list[LlmDoc], yield_item.response)
-        elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID:
+            final_search_results = cast(list[LlmDoc], yield_item.response)
+        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
             search_contexts = yield_item.response.contexts
             for doc in search_contexts:
                 if doc.document_id not in initial_search_results:
@@ -52,6 +53,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     )
 
     # For now, we don't do multiple tool calls, so we ignore the tool_message
-    process_llm_stream(stream, True)
+    process_llm_stream(
+        stream,
+        True,
+        final_search_results=final_search_results,
+        displayed_search_results=initial_search_results,
+    )
 
     return BasicOutput()
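Aside: the two hunks above change what the basic flow collects from the tool's streamed responses. The FINAL_CONTEXT_DOCUMENTS_ID payload is what the LLM actually cited against, while the SEARCH_DOC_CONTENT_ID payload is what was displayed to the user; both now reach process_llm_stream so citation numbers can be remapped between the two orderings. A minimal sketch of that splitting contract, with ToolResponse as a stand-in dataclass rather than the onyx class:

from dataclasses import dataclass
from typing import Any

FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
SEARCH_DOC_CONTENT_ID = "search_doc_content"

@dataclass
class ToolResponse:  # stand-in for onyx.tools.models.ToolResponse
    id: str
    response: Any

def split_tool_responses(responses: list[ToolResponse]) -> tuple[list[Any], list[str]]:
    # Collect the LLM-facing docs and the ids of the docs that were displayed.
    final_docs: list[Any] = []
    displayed_ids: list[str] = []
    for item in responses:
        if item.id == FINAL_CONTEXT_DOCUMENTS_ID:
            final_docs = item.response
        elif item.id == SEARCH_DOC_CONTENT_ID:
            for doc in item.response.contexts:
                if doc.document_id not in displayed_ids:
                    displayed_ids.append(doc.document_id)
    return final_docs, displayed_ids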
@@ -42,11 +42,13 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()
 
+    print("tool_kickoff", tool_kickoff)
     # TODO: custom events for yields
     emit_packet(tool_kickoff)
 
     tool_responses = []
     for response in tool_runner.tool_responses():
+        print("response", response.id)
         tool_responses.append(response)
         emit_packet(response)
 
@@ -6,7 +6,12 @@ from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage
 
+from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
+from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
+from onyx.chat.stream_processing.answer_response_handler import (
+    DummyAnswerResponseHandler,
+)
+from onyx.chat.stream_processing.utils import map_document_id_order
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -29,6 +34,18 @@ def process_llm_stream(
     tool_call_chunk = AIMessageChunk(content="")
+    # for response in response_handler_manager.handle_llm_response(stream):
 
+    print("final_search_results", final_search_results)
+    print("displayed_search_results", displayed_search_results)
+    if final_search_results and displayed_search_results:
+        answer_handler: AnswerResponseHandler = CitationResponseHandler(
+            context_docs=final_search_results,
+            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
+            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
+        )
+    else:
+        answer_handler = DummyAnswerResponseHandler()
+
+    print("entering stream")
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
     for response in stream:
@@ -44,9 +61,11 @@
             tool_call_chunk += response  # type: ignore
         elif should_stream_answer:
             # TODO: handle emitting of CitationInfo
-            dispatch_custom_event(
-                "basic_response",
-                OnyxAnswerPiece(answer_piece=answer_piece),
-            )
+            for response_part in answer_handler.handle_response_part(response, []):
+                print("resp part", response_part)
+                dispatch_custom_event(
+                    "basic_response",
+                    response_part,
+                )
 
     return cast(AIMessageChunk, tool_call_chunk)
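Aside: process_llm_stream now builds a CitationResponseHandler only when both document lists arrived, and falls back to a no-op handler otherwise. The rank maps are assumed to give each document_id its position in a list; a toy version for intuition (the real map_document_id_order lives in onyx.chat.stream_processing.utils and may differ):

from typing import Protocol

class HasDocumentId(Protocol):
    document_id: str

def map_document_id_order(docs: list[HasDocumentId]) -> dict[str, int]:
    # Assumed behavior: 1-based rank of each document in the given order.
    return {doc.document_id: rank for rank, doc in enumerate(docs, start=1)}

def translate_citation(
    citation_num: int,
    final_rank: dict[str, int],
    display_rank: dict[str, int],
) -> int | None:
    # A citation number refers to a rank in the final (LLM-facing) order;
    # report the rank the same document holds in the displayed order.
    for doc_id, rank in final_rank.items():
        if rank == citation_num:
            return display_rank.get(doc_id)
    return None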
@@ -49,9 +49,6 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
-from onyx.tools.tool_implementations.search_like_tool_utils import (
-    ORIGINAL_CONTEXT_DOCUMENTS_ID,
-)
 from onyx.utils.logger import setup_logger
 from onyx.utils.special_types import JSON_ro
 
@@ -395,7 +392,7 @@ class SearchTool(Tool):
                 final_search_results = cast(list[LlmDoc], yield_item.response)
             elif (
                 isinstance(yield_item, ToolResponse)
-                and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
+                and yield_item.id == SEARCH_DOC_CONTENT_ID
             ):
                 search_contexts = yield_item.response.contexts
                 # original_doc_search_rank = 1
@@ -15,7 +15,6 @@ from onyx.tools.message import ToolCallSummary
 from onyx.tools.models import ToolResponse
 
 
-ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
 FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
 
 
@@ -7,17 +7,23 @@ from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage
 
 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.chat.chat_utils import llm_doc_from_inference_section
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import LlmDoc
+from onyx.chat.models import OnyxContext
+from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import DocumentSource
+from onyx.context.search.models import InferenceChunk
+from onyx.context.search.models import InferenceSection
 from onyx.context.search.models import SearchRequest
 from onyx.llm.interfaces import LLM
 from onyx.llm.interfaces import LLMConfig
 from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolResponse
+from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
@@ -82,37 +88,83 @@ def mock_llm() -> MagicMock:
 
 
 @pytest.fixture
-def mock_search_results() -> list[LlmDoc]:
+def mock_inference_sections() -> list[InferenceSection]:
     return [
-        LlmDoc(
-            content="Search result 1",
-            source_type=DocumentSource.WEB,
-            metadata={"id": "doc1"},
-            document_id="doc1",
-            blurb="Blurb 1",
-            semantic_identifier="Semantic ID 1",
-            updated_at=datetime(2023, 1, 1),
-            link="https://example.com/doc1",
-            source_links={0: "https://example.com/doc1"},
-            match_highlights=[],
+        InferenceSection(
+            combined_content="Search result 1",
+            center_chunk=InferenceChunk(
+                chunk_id=1,
+                section_continuation=False,
+                title=None,
+                boost=1,
+                recency_bias=0.5,
+                score=1.0,
+                hidden=False,
+                content="Search result 1",
+                source_type=DocumentSource.WEB,
+                metadata={"id": "doc1"},
+                document_id="doc1",
+                blurb="Blurb 1",
+                semantic_identifier="Semantic ID 1",
+                updated_at=datetime(2023, 1, 1),
+                source_links={0: "https://example.com/doc1"},
+                match_highlights=[],
+            ),
+            chunks=MagicMock(),
         ),
-        LlmDoc(
-            content="Search result 2",
-            source_type=DocumentSource.WEB,
-            metadata={"id": "doc2"},
-            document_id="doc2",
-            blurb="Blurb 2",
-            semantic_identifier="Semantic ID 2",
-            updated_at=datetime(2023, 1, 2),
-            link="https://example.com/doc2",
-            source_links={0: "https://example.com/doc2"},
-            match_highlights=[],
+        InferenceSection(
+            combined_content="Search result 2",
+            center_chunk=InferenceChunk(
+                chunk_id=2,
+                section_continuation=False,
+                title=None,
+                boost=1,
+                recency_bias=0.5,
+                score=1.0,
+                hidden=False,
+                content="Search result 2",
+                source_type=DocumentSource.WEB,
+                metadata={"id": "doc2"},
+                document_id="doc2",
+                blurb="Blurb 2",
+                semantic_identifier="Semantic ID 2",
+                updated_at=datetime(2023, 1, 2),
+                source_links={0: "https://example.com/doc2"},
+                match_highlights=[],
+            ),
+            chunks=MagicMock(),
         ),
     ]
 
 
 @pytest.fixture
-def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
+def mock_search_results(
+    mock_inference_sections: list[InferenceSection],
+) -> list[LlmDoc]:
+    return [
+        llm_doc_from_inference_section(section) for section in mock_inference_sections
+    ]
+
+
+@pytest.fixture
+def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
+    return OnyxContexts(
+        contexts=[
+            OnyxContext(
+                content=section.combined_content,
+                document_id=section.center_chunk.document_id,
+                semantic_identifier=section.center_chunk.semantic_identifier,
+                blurb=section.center_chunk.blurb,
+            )
+            for section in mock_inference_sections
+        ]
+    )
+
+
+@pytest.fixture
+def mock_search_tool(
+    mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
+) -> MagicMock:
     mock_tool = MagicMock(spec=SearchTool)
     mock_tool.name = "search"
     mock_tool.build_tool_message_content.return_value = "search_response"
@@ -121,7 +173,8 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
         json.loads(doc.model_dump_json()) for doc in mock_search_results
     ]
     mock_tool.run.return_value = [
-        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)
+        ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
+        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
     ]
     mock_tool.tool_definition.return_value = {
         "type": "function",
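Aside: the mock now yields the displayed contexts before the final context documents, which is the order the updated assertions below depend on (an inference from the test expectations, not from SearchTool's source). Illustrative check only:

from unittest.mock import MagicMock

def response_ids(mock_tool: MagicMock) -> list[str]:
    # Only the ids and their order matter to the consuming flow.
    return [resp.id for resp in mock_tool.run.return_value]

# With the fixture above: ["search_doc_content", "final_context_documents"]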
@@ -17,6 +17,7 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
@@ -25,6 +26,10 @@ from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
+from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
+from onyx.tools.tool_implementations.search_like_tool_utils import (
+    FINAL_CONTEXT_DOCUMENTS_ID,
+)
 from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
 from tests.unit.onyx.chat.conftest import QUERY
 
@@ -215,12 +220,13 @@ def test_answer_with_search_call(
 def test_answer_with_search_no_tool_calling(
     answer_instance: Answer,
     mock_search_results: list[LlmDoc],
+    mock_contexts: OnyxContexts,
     mock_search_tool: MagicMock,
 ) -> None:
-    answer_instance.tools = [mock_search_tool]
+    answer_instance.agent_search_config.tools = [mock_search_tool]
 
     # Set up the LLM mock to return an answer
-    mock_llm = cast(Mock, answer_instance.llm)
+    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
     mock_llm.stream.return_value = [
         AIMessageChunk(content="Based on the search results, "),
         AIMessageChunk(content="the answer is abc[1]. "),
@@ -228,10 +234,15 @@ def test_answer_with_search_no_tool_calling(
     ]
 
     # Force non-tool calling behavior
-    answer_instance.using_tool_calling_llm = False
+    answer_instance.agent_search_config.using_tool_calling_llm = False
 
     # Process the output
     output = list(answer_instance.processed_streamed_output)
+    print("-" * 50)
+    for v in output:
+        print(v)
+        print()
+    print("-" * 50)
 
     # Assertions
     assert len(output) == 7
@@ -239,21 +250,25 @@
         tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
     )
     assert output[1] == ToolResponse(
-        id="final_context_documents",
+        id=SEARCH_DOC_CONTENT_ID,
+        response=mock_contexts,
+    )
+    assert output[2] == ToolResponse(
+        id=FINAL_CONTEXT_DOCUMENTS_ID,
         response=mock_search_results,
     )
-    assert output[2] == ToolCallFinalResult(
+    assert output[3] == ToolCallFinalResult(
         tool_name="search",
         tool_args=DEFAULT_SEARCH_ARGS,
         tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
     )
-    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
     expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[4] == expected_citation
-    assert output[5] == OnyxAnswerPiece(
+    assert output[5] == expected_citation
+    assert output[6] == OnyxAnswerPiece(
         answer_piece="the answer is abc[[1]](https://example.com/doc1). "
     )
-    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
 
     expected_answer = (
         "Based on the search results, "
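Aside: the reindexed assertions encode the citation pipeline end to end: the raw chunk "the answer is abc[1]. " is emitted as "the answer is abc[[1]](https://example.com/doc1). " because citation 1 resolves to doc1's link in the displayed mapping. A toy rewrite for intuition (not the CitationResponseHandler itself):

import re

def rewrite_citations(text: str, links: dict[int, str]) -> str:
    # Replace [n] with [[n]](url) when citation n has a known link.
    def repl(match: re.Match) -> str:
        num = int(match.group(1))
        url = links.get(num)
        return f"[[{num}]]({url})" if url else match.group(0)
    return re.sub(r"\[(\d+)\]", repl, text)

assert (
    rewrite_citations("the answer is abc[1]. ", {1: "https://example.com/doc1"})
    == "the answer is abc[[1]](https://example.com/doc1). "
)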