Commit ca1f176c61: fixed basic flow citations and second test
Parent: 3ced9bc28b
Repository: https://github.com/danswer-ai/danswer.git
@@ -7,11 +7,11 @@ from onyx.agents.agent_search.basic.states import BasicState
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.models import LlmDoc
-from onyx.tools.tool_implementations.search_like_tool_utils import (
-    FINAL_CONTEXT_DOCUMENTS_ID,
+from onyx.tools.tool_implementations.search.search_tool import (
+    SEARCH_DOC_CONTENT_ID,
 )
 from onyx.tools.tool_implementations.search_like_tool_utils import (
-    ORIGINAL_CONTEXT_DOCUMENTS_ID,
+    FINAL_CONTEXT_DOCUMENTS_ID,
 )


@@ -34,11 +34,12 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
         using_tool_calling_llm=agent_config.using_tool_calling_llm,
     )

+    final_search_results = []
     initial_search_results = []
     for yield_item in tool_call_responses:
         if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
-            cast(list[LlmDoc], yield_item.response)
-        elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID:
+            final_search_results = cast(list[LlmDoc], yield_item.response)
+        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
             search_contexts = yield_item.response.contexts
             for doc in search_contexts:
                 if doc.document_id not in initial_search_results:
@@ -52,6 +53,11 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
     )

     # For now, we don't do multiple tool calls, so we ignore the tool_message
-    process_llm_stream(stream, True)
+    process_llm_stream(
+        stream,
+        True,
+        final_search_results=final_search_results,
+        displayed_search_results=initial_search_results,
+    )

     return BasicOutput()
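
Taken together, the two hunks above make basic_use_tool_response collect two parallel views of the search output: the final context documents (what the LLM cites against) and the documents originally displayed to the user, and then thread both into process_llm_stream so citation numbers can be remapped. A minimal, self-contained sketch of that collection step, not the repo's code: ToolResponse is stubbed here, the constant string values are taken from this diff, and appending document ids to initial_search_results is an assumption since the hunk cuts off mid-loop.

```python
# Sketch only; types and exact field handling are assumptions.
from dataclasses import dataclass
from typing import Any

SEARCH_DOC_CONTENT_ID = "search_doc_content"  # value assumed from the removed constant below
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"


@dataclass
class ToolResponse:  # stub of onyx.tools.models.ToolResponse
    id: str
    response: Any


def collect_search_results(
    tool_call_responses: list[ToolResponse],
) -> tuple[list, list]:
    final_search_results: list = []  # docs the LLM cites against
    initial_search_results: list = []  # doc ids the user was shown
    for yield_item in tool_call_responses:
        if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
            final_search_results = yield_item.response
        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
            for doc in yield_item.response.contexts:
                if doc.document_id not in initial_search_results:
                    initial_search_results.append(doc.document_id)
    return final_search_results, initial_search_results
```
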
@@ -42,11 +42,13 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()

+    print("tool_kickoff", tool_kickoff)
     # TODO: custom events for yields
     emit_packet(tool_kickoff)

     tool_responses = []
     for response in tool_runner.tool_responses():
+        print("response", response.id)
         tool_responses.append(response)
         emit_packet(response)

@@ -6,7 +6,12 @@ from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage

 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
+from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
+from onyx.chat.stream_processing.answer_response_handler import (
+    DummyAnswerResponseHandler,
+)
+from onyx.chat.stream_processing.utils import map_document_id_order
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -29,6 +34,18 @@ def process_llm_stream(
     tool_call_chunk = AIMessageChunk(content="")
     # for response in response_handler_manager.handle_llm_response(stream):

+    print("final_search_results", final_search_results)
+    print("displayed_search_results", displayed_search_results)
+    if final_search_results and displayed_search_results:
+        answer_handler: AnswerResponseHandler = CitationResponseHandler(
+            context_docs=final_search_results,
+            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
+            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
+        )
+    else:
+        answer_handler = DummyAnswerResponseHandler()
+
+    print("entering stream")
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
     for response in stream:
@@ -44,9 +61,11 @@ def process_llm_stream(
             tool_call_chunk += response  # type: ignore
         elif should_stream_answer:
            # TODO: handle emitting of CitationInfo
-            dispatch_custom_event(
-                "basic_response",
-                OnyxAnswerPiece(answer_piece=answer_piece),
-            )
+            for response_part in answer_handler.handle_response_part(response, []):
+                print("resp part", response_part)
+                dispatch_custom_event(
+                    "basic_response",
+                    response_part,
+                )

     return cast(AIMessageChunk, tool_call_chunk)
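
The process_llm_stream changes above pick a CitationResponseHandler only when both result lists are present, using two rank maps so a citation emitted against the final (deduplicated) document order can be rendered with the number the user originally saw. A toy illustration of that remapping idea follows; it assumes map_document_id_order returns a 1-based document_id-to-rank dict, and uses plain strings where the real code passes LlmDoc objects.

```python
# Assumed behavior of map_document_id_order: 1-based rank per id, in list order.
def map_document_id_order(doc_ids: list[str]) -> dict[str, int]:
    return {doc_id: rank for rank, doc_id in enumerate(doc_ids, start=1)}


final_rank = map_document_id_order(["doc_b", "doc_a"])    # order the LLM cited against
display_rank = map_document_id_order(["doc_a", "doc_b"])  # order shown to the user

# A "[1]" citation against the final order refers to doc_b, which the user
# saw in position 2, so it should render as [2].
cited_id = next(d for d, r in final_rank.items() if r == 1)
print(display_rank[cited_id])  # -> 2
```
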
@@ -49,9 +49,6 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
-from onyx.tools.tool_implementations.search_like_tool_utils import (
-    ORIGINAL_CONTEXT_DOCUMENTS_ID,
-)
 from onyx.utils.logger import setup_logger
 from onyx.utils.special_types import JSON_ro

@@ -395,7 +392,7 @@ class SearchTool(Tool):
                 final_search_results = cast(list[LlmDoc], yield_item.response)
             elif (
                 isinstance(yield_item, ToolResponse)
-                and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
+                and yield_item.id == SEARCH_DOC_CONTENT_ID
             ):
                 search_contexts = yield_item.response.contexts
                 # original_doc_search_rank = 1
@@ -15,7 +15,6 @@ from onyx.tools.message import ToolCallSummary
 from onyx.tools.models import ToolResponse


-ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
 FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"


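
This hunk drops ORIGINAL_CONTEXT_DOCUMENTS_ID from search_like_tool_utils. Judging from the new import sites elsewhere in this diff (onyx.tools.tool_implementations.search.search_tool), the id now lives in search_tool.py under a new name; the string value is assumed to be unchanged from the line removed above.

```python
# Presumed new home of the renamed constant, value assumed unchanged:
SEARCH_DOC_CONTENT_ID = "search_doc_content"
```
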
@@ -7,17 +7,23 @@ from langchain_core.messages import HumanMessage
 from langchain_core.messages import SystemMessage

 from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.chat.chat_utils import llm_doc_from_inference_section
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import LlmDoc
+from onyx.chat.models import OnyxContext
+from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import DocumentSource
+from onyx.context.search.models import InferenceChunk
+from onyx.context.search.models import InferenceSection
 from onyx.context.search.models import SearchRequest
 from onyx.llm.interfaces import LLM
 from onyx.llm.interfaces import LLMConfig
 from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolResponse
+from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
@@ -82,9 +88,18 @@ def mock_llm() -> MagicMock:


 @pytest.fixture
-def mock_search_results() -> list[LlmDoc]:
+def mock_inference_sections() -> list[InferenceSection]:
     return [
-        LlmDoc(
+        InferenceSection(
+            combined_content="Search result 1",
+            center_chunk=InferenceChunk(
+                chunk_id=1,
+                section_continuation=False,
+                title=None,
+                boost=1,
+                recency_bias=0.5,
+                score=1.0,
+                hidden=False,
                 content="Search result 1",
                 source_type=DocumentSource.WEB,
                 metadata={"id": "doc1"},
@@ -92,11 +107,21 @@ def mock_search_results() -> list[LlmDoc]:
                 blurb="Blurb 1",
                 semantic_identifier="Semantic ID 1",
                 updated_at=datetime(2023, 1, 1),
-                link="https://example.com/doc1",
                 source_links={0: "https://example.com/doc1"},
                 match_highlights=[],
             ),
-        LlmDoc(
+            chunks=MagicMock(),
+        ),
+        InferenceSection(
+            combined_content="Search result 2",
+            center_chunk=InferenceChunk(
+                chunk_id=2,
+                section_continuation=False,
+                title=None,
+                boost=1,
+                recency_bias=0.5,
+                score=1.0,
+                hidden=False,
                 content="Search result 2",
                 source_type=DocumentSource.WEB,
                 metadata={"id": "doc2"},
@@ -104,15 +129,42 @@ def mock_search_results() -> list[LlmDoc]:
                 blurb="Blurb 2",
                 semantic_identifier="Semantic ID 2",
                 updated_at=datetime(2023, 1, 2),
-                link="https://example.com/doc2",
                 source_links={0: "https://example.com/doc2"},
                 match_highlights=[],
             ),
+            chunks=MagicMock(),
+        ),
     ]


 @pytest.fixture
-def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
+def mock_search_results(
+    mock_inference_sections: list[InferenceSection],
+) -> list[LlmDoc]:
+    return [
+        llm_doc_from_inference_section(section) for section in mock_inference_sections
+    ]
+
+
+@pytest.fixture
+def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
+    return OnyxContexts(
+        contexts=[
+            OnyxContext(
+                content=section.combined_content,
+                document_id=section.center_chunk.document_id,
+                semantic_identifier=section.center_chunk.semantic_identifier,
+                blurb=section.center_chunk.blurb,
+            )
+            for section in mock_inference_sections
+        ]
+    )
+
+
+@pytest.fixture
+def mock_search_tool(
+    mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
+) -> MagicMock:
     mock_tool = MagicMock(spec=SearchTool)
     mock_tool.name = "search"
     mock_tool.build_tool_message_content.return_value = "search_response"
@@ -121,7 +173,8 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
         json.loads(doc.model_dump_json()) for doc in mock_search_results
     ]
     mock_tool.run.return_value = [
-        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)
+        ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
+        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
     ]
     mock_tool.tool_definition.return_value = {
         "type": "function",
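
The reworked fixtures build InferenceSections first, derive the LlmDoc list from them through llm_doc_from_inference_section, and make the mocked SearchTool.run yield the raw contexts before the final context documents. A hedged sketch of what the conversion helper presumably does, with field names inferred from the fixture data above; the real helper in onyx.chat.chat_utils returns an LlmDoc and may differ in detail.

```python
# Sketch under assumptions: returns a plain dict instead of LlmDoc, and the
# field mapping is inferred from the fixtures, not read from the repo source.
def llm_doc_from_inference_section_sketch(section):
    chunk = section.center_chunk
    return {
        "document_id": chunk.document_id,
        "content": section.combined_content,  # combined text of the section
        "blurb": chunk.blurb,
        "semantic_identifier": chunk.semantic_identifier,
        "source_links": chunk.source_links,
        "source_type": chunk.source_type,
        "metadata": chunk.metadata,
        "updated_at": chunk.updated_at,
    }
```
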
@@ -17,6 +17,7 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
@@ -25,6 +26,10 @@ from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
+from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
+from onyx.tools.tool_implementations.search_like_tool_utils import (
+    FINAL_CONTEXT_DOCUMENTS_ID,
+)
 from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
 from tests.unit.onyx.chat.conftest import QUERY

@@ -215,12 +220,13 @@ def test_answer_with_search_call(
 def test_answer_with_search_no_tool_calling(
     answer_instance: Answer,
     mock_search_results: list[LlmDoc],
+    mock_contexts: OnyxContexts,
     mock_search_tool: MagicMock,
 ) -> None:
-    answer_instance.tools = [mock_search_tool]
+    answer_instance.agent_search_config.tools = [mock_search_tool]

     # Set up the LLM mock to return an answer
-    mock_llm = cast(Mock, answer_instance.llm)
+    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
     mock_llm.stream.return_value = [
         AIMessageChunk(content="Based on the search results, "),
         AIMessageChunk(content="the answer is abc[1]. "),
@@ -228,10 +234,15 @@ def test_answer_with_search_no_tool_calling(
     ]

     # Force non-tool calling behavior
-    answer_instance.using_tool_calling_llm = False
+    answer_instance.agent_search_config.using_tool_calling_llm = False

     # Process the output
     output = list(answer_instance.processed_streamed_output)
+    print("-" * 50)
+    for v in output:
+        print(v)
+        print()
+    print("-" * 50)

     # Assertions
     assert len(output) == 7
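
The test now reaches everything through the AgentSearchConfig attached to the Answer instance rather than attributes on Answer itself. A self-contained sketch of that wiring; attribute names are exactly those used in the diff, while AgentSearchConfig is assumed to expose them as plain attributes.

```python
from typing import Any
from unittest.mock import MagicMock


def configure_for_non_tool_calling(answer_instance: Any, search_tool: MagicMock) -> Any:
    # Everything the test touches now hangs off agent_search_config.
    cfg = answer_instance.agent_search_config
    cfg.tools = [search_tool]           # was: answer_instance.tools
    cfg.using_tool_calling_llm = False  # was: answer_instance.using_tool_calling_llm
    return cfg.primary_llm              # was: answer_instance.llm
```
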
@@ -239,21 +250,25 @@ def test_answer_with_search_no_tool_calling(
         tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
     )
     assert output[1] == ToolResponse(
-        id="final_context_documents",
+        id=SEARCH_DOC_CONTENT_ID,
+        response=mock_contexts,
+    )
+    assert output[2] == ToolResponse(
+        id=FINAL_CONTEXT_DOCUMENTS_ID,
         response=mock_search_results,
     )
-    assert output[2] == ToolCallFinalResult(
+    assert output[3] == ToolCallFinalResult(
         tool_name="search",
         tool_args=DEFAULT_SEARCH_ARGS,
         tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
     )
-    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
     expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[4] == expected_citation
-    assert output[5] == OnyxAnswerPiece(
+    assert output[5] == expected_citation
+    assert output[6] == OnyxAnswerPiece(
         answer_piece="the answer is abc[[1]](https://example.com/doc1). "
     )
-    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")

     expected_answer = (
         "Based on the search results, "
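
For reference, the updated assertions encode the following packet order for the non-tool-calling flow: the raw search contexts now arrive as their own ToolResponse before the final context documents, shifting every later index up by one. A summary in code form, using descriptive labels only, not the repo's types.

```python
# Expected packet order per the assertions above (labels, not repo objects):
EXPECTED_PACKETS = [
    ("output[0]", "ToolCallKickoff(search, DEFAULT_SEARCH_ARGS)"),
    ("output[1]", "ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts)"),
    ("output[2]", "ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)"),
    ("output[3]", "ToolCallFinalResult(search, ...)"),
    ("output[4]", "OnyxAnswerPiece('Based on the search results, ')"),
    ("output[5]", "CitationInfo(citation_num=1, document_id='doc1')"),
    ("output[6]", "OnyxAnswerPiece('the answer is abc[[1]](https://example.com/doc1). ')"),
    ("output[7]", "OnyxAnswerPiece('This is some other stuff.')"),
]
```
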