mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-18 15:50:13 +02:00
199 lines
6.4 KiB
Python
199 lines
6.4 KiB
Python
import json
|
|
from datetime import datetime
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from langchain_core.messages import HumanMessage
|
|
from langchain_core.messages import SystemMessage
|
|
|
|
from onyx.agents.agent_search.models import AgentSearchConfig
|
|
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
|
from onyx.chat.models import AnswerStyleConfig
|
|
from onyx.chat.models import CitationConfig
|
|
from onyx.chat.models import LlmDoc
|
|
from onyx.chat.models import OnyxContext
|
|
from onyx.chat.models import OnyxContexts
|
|
from onyx.chat.models import PromptConfig
|
|
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
|
from onyx.configs.constants import DocumentSource
|
|
from onyx.context.search.models import InferenceChunk
|
|
from onyx.context.search.models import InferenceSection
|
|
from onyx.context.search.models import SearchRequest
|
|
from onyx.llm.interfaces import LLM
|
|
from onyx.llm.interfaces import LLMConfig
|
|
from onyx.tools.force import ForceUseTool
|
|
from onyx.tools.models import ToolResponse
|
|
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
|
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
|
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
|
FINAL_CONTEXT_DOCUMENTS_ID,
|
|
)
|
|
|
|
QUERY = "Test question"
|
|
DEFAULT_SEARCH_ARGS = {"query": "search"}
|
|
|
|
|
|
@pytest.fixture
|
|
def answer_style_config() -> AnswerStyleConfig:
|
|
return AnswerStyleConfig(citation_config=CitationConfig())
|
|
|
|
|
|
@pytest.fixture
|
|
def agent_search_config(
|
|
mock_llm: LLM, mock_search_tool: SearchTool
|
|
) -> AgentSearchConfig:
|
|
return AgentSearchConfig(
|
|
search_request=SearchRequest(query=QUERY),
|
|
primary_llm=mock_llm,
|
|
fast_llm=mock_llm,
|
|
search_tool=mock_search_tool,
|
|
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
|
prompt_builder=AnswerPromptBuilder(
|
|
user_message=HumanMessage(content=QUERY),
|
|
message_history=[],
|
|
llm_config=mock_llm.config,
|
|
raw_user_query=QUERY,
|
|
raw_user_uploaded_files=[],
|
|
),
|
|
chat_session_id=None,
|
|
message_id=1,
|
|
use_persistence=True,
|
|
db_session=None,
|
|
use_agentic_search=False,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def prompt_config() -> PromptConfig:
|
|
return PromptConfig(
|
|
system_prompt="System prompt",
|
|
task_prompt="Task prompt",
|
|
datetime_aware=False,
|
|
include_citations=True,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_llm() -> MagicMock:
|
|
mock_llm_obj = MagicMock()
|
|
mock_llm_obj.config = LLMConfig(
|
|
model_provider="openai",
|
|
model_name="gpt-4o",
|
|
temperature=0.0,
|
|
api_key=None,
|
|
api_base=None,
|
|
api_version=None,
|
|
)
|
|
return mock_llm_obj
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_inference_sections() -> list[InferenceSection]:
|
|
return [
|
|
InferenceSection(
|
|
combined_content="Search result 1",
|
|
center_chunk=InferenceChunk(
|
|
chunk_id=1,
|
|
section_continuation=False,
|
|
title=None,
|
|
boost=1,
|
|
recency_bias=0.5,
|
|
score=1.0,
|
|
hidden=False,
|
|
content="Search result 1",
|
|
source_type=DocumentSource.WEB,
|
|
metadata={"id": "doc1"},
|
|
document_id="doc1",
|
|
blurb="Blurb 1",
|
|
semantic_identifier="Semantic ID 1",
|
|
updated_at=datetime(2023, 1, 1),
|
|
source_links={0: "https://example.com/doc1"},
|
|
match_highlights=[],
|
|
),
|
|
chunks=MagicMock(),
|
|
),
|
|
InferenceSection(
|
|
combined_content="Search result 2",
|
|
center_chunk=InferenceChunk(
|
|
chunk_id=2,
|
|
section_continuation=False,
|
|
title=None,
|
|
boost=1,
|
|
recency_bias=0.5,
|
|
score=1.0,
|
|
hidden=False,
|
|
content="Search result 2",
|
|
source_type=DocumentSource.WEB,
|
|
metadata={"id": "doc2"},
|
|
document_id="doc2",
|
|
blurb="Blurb 2",
|
|
semantic_identifier="Semantic ID 2",
|
|
updated_at=datetime(2023, 1, 2),
|
|
source_links={0: "https://example.com/doc2"},
|
|
match_highlights=[],
|
|
),
|
|
chunks=MagicMock(),
|
|
),
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_search_results(
|
|
mock_inference_sections: list[InferenceSection],
|
|
) -> list[LlmDoc]:
|
|
return [
|
|
llm_doc_from_inference_section(section) for section in mock_inference_sections
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
|
|
return OnyxContexts(
|
|
contexts=[
|
|
OnyxContext(
|
|
content=section.combined_content,
|
|
document_id=section.center_chunk.document_id,
|
|
semantic_identifier=section.center_chunk.semantic_identifier,
|
|
blurb=section.center_chunk.blurb,
|
|
)
|
|
for section in mock_inference_sections
|
|
]
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_search_tool(
|
|
mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
|
|
) -> MagicMock:
|
|
mock_tool = MagicMock(spec=SearchTool)
|
|
mock_tool.name = "search"
|
|
mock_tool.build_tool_message_content.return_value = "search_response"
|
|
mock_tool.get_args_for_non_tool_calling_llm.return_value = DEFAULT_SEARCH_ARGS
|
|
mock_tool.final_result.return_value = [
|
|
json.loads(doc.model_dump_json()) for doc in mock_search_results
|
|
]
|
|
mock_tool.run.return_value = [
|
|
ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
|
|
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
|
|
]
|
|
mock_tool.tool_definition.return_value = {
|
|
"type": "function",
|
|
"function": {
|
|
"name": "search",
|
|
"description": "Search for information",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {"type": "string", "description": "The search query"},
|
|
},
|
|
"required": ["query"],
|
|
},
|
|
},
|
|
}
|
|
mock_post_search_tool_prompt_builder = MagicMock(spec=AnswerPromptBuilder)
|
|
mock_post_search_tool_prompt_builder.build.return_value = [
|
|
SystemMessage(content="Updated system prompt"),
|
|
]
|
|
mock_tool.build_next_prompt.return_value = mock_post_search_tool_prompt_builder
|
|
return mock_tool
|