add custom tool testing

This commit is contained in:
pablodanswer 2024-09-20 18:02:55 -07:00
parent c90d7da02d
commit ffc5dd7b49

View File

@ -1,83 +1,131 @@
from unittest.mock import MagicMock
import unittest
from unittest.mock import patch
import pytest
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CustomTool
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse
def test_custom_tool_run():
# Mock MethodSpec
method_spec = MagicMock(MethodSpec)
method_spec.method = "POST"
method_spec.build_url.return_value = "http://test-url.com/api/endpoint"
method_spec.operation_id = "testOperation" # Add this line
method_spec.name = "test"
method_spec.summary = "test"
class TestCustomTool(unittest.TestCase):
"""
Test suite for CustomTool functionality.
This class tests the creation, running, and result handling of custom tools
based on OpenAPI schemas.
"""
# Create CustomTool instance
custom_tool = CustomTool(method_spec, "http://test-url.com")
# Mock the requests.request function
with patch("requests.request") as mock_request:
# Set up the mock response
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
# Run the tool
result = list(custom_tool.run(request_body={"key": "value"}))
# Assert that the request was made with correct parameters
mock_request.assert_called_once_with(
"POST", "http://test-url.com/api/endpoint", json={"key": "value"}
)
# Check the result
assert len(result) == 1
assert result[0].response.tool_result == {"result": "success"}
def test_build_custom_tools_with_dynamic_schema():
openapi_schema = {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"servers": [{"url": "http://test-url.com/CHAT_SESSION_ID"}],
"paths": {
"/endpoint/CHAT_SESSION_ID": {
"post": {
"summary": "Create a new Assistant",
"operationId": "testEndpoint",
"parameters": [],
def setUp(self):
"""
Set up the test environment before each test method.
Initializes an OpenAPI schema and DynamicSchemaInfo for testing.
"""
self.openapi_schema = {
"openapi": "3.0.0",
"info": {
"version": "1.0.0",
"title": "Assistants API",
"description": "An API for managing assistants",
},
"servers": [
{"url": "http://localhost:8080/CHAT_SESSION_ID/test/MESSAGE_ID"},
],
"paths": {
"/assistant/{assistant_id}": {
"GET": {
"summary": "Get a specific Assistant",
"operationId": "getAssistant",
"parameters": [
{
"name": "assistant_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
},
"POST": {
"summary": "Create a new Assistant",
"operationId": "createAssistant",
"parameters": [
{
"name": "assistant_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
"requestBody": {
"required": True,
"content": {
"application/json": {"schema": {"type": "object"}}
},
},
},
}
}
},
}
},
}
validate_openapi_schema(self.openapi_schema)
self.dynamic_schema_info = DynamicSchemaInfo(chat_session_id=10, message_id=20)
dynamic_schema_info = DynamicSchemaInfo(chat_session_id=123, message_id=456)
@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_run_get(self, mock_request):
"""
Test the GET method of a custom tool.
Verifies that the tool correctly constructs the URL and makes the GET request.
"""
tools = build_custom_tools_from_openapi_schema(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)
tools = build_custom_tools_from_openapi_schema(openapi_schema, dynamic_schema_info)
print("tools")
print(tools)
result = list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
mock_request.assert_called_once_with("GET", expected_url, json=None)
# assert len(tools) == 1
# assert isinstance(tools[0], CustomTool)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, CUSTOM_TOOL_RESPONSE_ID)
self.assertEqual(result[0].response.tool_name, "getAssistant")
# # Test that the dynamic schema info was applied
# with patch('requests.request') as mock_request:
# mock_response = MagicMock()
# mock_response.json.return_value = {"result": "success"}
# mock_request.return_value = mock_response
@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_run_post(self, mock_request):
"""
Test the POST method of a custom tool.
Verifies that the tool correctly constructs the URL and makes the POST request with the given body.
"""
tools = build_custom_tools_from_openapi_schema(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)
# list(tools[0].run())
result = list(tools[1].run(assistant_id="456"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/456"
mock_request.assert_called_once_with("POST", expected_url, json=None)
# mock_request.assert_called_once()
# print(mock_request.call_args)
# # call_args = mock_request.call_args[0]
# assert 123 in call_args[1] # URL should contain the session ID
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, CUSTOM_TOOL_RESPONSE_ID)
self.assertEqual(result[0].response.tool_name, "createAssistant")
def test_custom_tool_final_result(self):
"""
Test the final_result method of a custom tool.
Verifies that the method correctly extracts and returns the tool result.
"""
tools = build_custom_tools_from_openapi_schema(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)
mock_response = ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
tool_name="getAssistant",
tool_result={"id": "789", "name": "Final Assistant"},
),
)
final_result = tools[0].final_result(mock_response)
self.assertEqual(final_result, {"id": "789", "name": "Final Assistant"})
if __name__ == "__main__":