Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-10-11 13:46:07 +02:00)
Fix parallel tool calls (#2779)

* Fix parallel tool calls
* remove comments
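Previously, both conversion helpers hardcoded the tool call index to 0 ("only support a single tool call atm"), so a response containing multiple parallel tool calls collapsed into one. This change propagates each call's own index through the OpenAI-style dict conversion and through streamed chunk construction, and adds unit tests covering both the blocking and streaming paths.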
@@ -109,7 +109,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
                     "arguments": json.dumps(tool_call["args"]),
                 },
                 "type": "function",
-                "index": 0,  # only support a single tool call atm
+                "index": tool_call.get("index", 0),
             }
             for tool_call in message.tool_calls
         ]
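For illustration, a sketch of the payload shape this comprehension now produces for a message carrying two tool calls. The values are borrowed from the new tests below; the "id" and "name" keys sit outside the hunk, and the per-call indices assume the incoming tool_call dicts carry an "index" key (which is why the code falls back to 0 via .get):

    # Hypothetical converted payload for two parallel tool calls (illustrative only).
    # Before this change, both entries would have carried "index": 0.
    converted_tool_calls = [
        {
            "id": "call_1",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"location": "New York"}',
            },
            "index": 0,
        },
        {
            "id": "call_2",
            "type": "function",
            "function": {
                "name": "get_time",
                "arguments": '{"timezone": "EST"}',
            },
            "index": 1,
        },
    ]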
@@ -158,12 +158,13 @@ def _convert_delta_to_message_chunk(
     if tool_calls:
         tool_call = tool_calls[0]
         tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
+        idx = tool_call.index

         tool_call_chunk = ToolCallChunk(
             name=tool_name,
             id=tool_call.id,
             args=tool_call.function.arguments,
-            index=0,  # only support a single tool call atm
+            index=idx,
         )

         return AIMessageChunk(
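Why the index matters downstream: when AIMessageChunks are added together, langchain_core merges tool call chunks that share an index, so two interleaved calls need distinct indices or their argument fragments get fused into one broken call. A minimal sketch, assuming a langchain_core version whose chunk merging behaves as the streaming test below relies on:

    from langchain_core.messages import AIMessageChunk
    from langchain_core.messages.tool import ToolCallChunk

    # Two fragments of call_1 share index 0 and get merged; call_2 keeps index 1.
    chunks = [
        AIMessageChunk(
            content="",
            tool_call_chunks=[
                ToolCallChunk(
                    name="get_weather", args='{"location": ', id="call_1", index=0
                )
            ],
        ),
        AIMessageChunk(
            content="",
            tool_call_chunks=[
                ToolCallChunk(name=None, args='"New York"}', id=None, index=0)
            ],
        ),
        AIMessageChunk(
            content="",
            tool_call_chunks=[
                ToolCallChunk(
                    name="get_time", args='{"timezone": "EST"}', id="call_2", index=1
                )
            ],
        ),
    ]

    merged = chunks[0]
    for chunk in chunks[1:]:
        merged = merged + chunk

    # With per-call indices, both calls survive the merge; with index=0
    # everywhere, the fragments of different calls would be concatenated.
    assert len(merged.tool_calls) == 2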
backend/tests/unit/danswer/llm/test_chat_llm.py (new file, 290 lines)
@@ -0,0 +1,290 @@
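The new test module drives DefaultMultiLLM against a mocked litellm.completion, covering the blocking path (two tool calls in a single response) and the streaming path (one call's arguments split across chunks, plus a second call at a different index):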
from unittest.mock import patch

import litellm
import pytest
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import HumanMessage
from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction

from danswer.llm.chat_llm import DefaultMultiLLM


def _create_delta(
    role: str | None = None,
    content: str | None = None,
    tool_calls: list[ChatCompletionDeltaToolCall] | None = None,
) -> Delta:
    delta = Delta(role=role, content=content)
    # NOTE: for some reason, if you pass tool_calls to the constructor, it doesn't actually
    # get set, so we have to do it this way
    delta.tool_calls = tool_calls
    return delta


@pytest.fixture
def default_multi_llm() -> DefaultMultiLLM:
    return DefaultMultiLLM(
        api_key="test_key",
        timeout=30,
        model_provider="openai",
        model_name="gpt-3.5-turbo",
    )


def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
    # Mock the litellm.completion function
    with patch("danswer.llm.chat_llm.litellm.completion") as mock_completion:
        # Create a mock response with multiple tool calls using litellm objects
        mock_response = litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    finish_reason="tool_calls",
                    index=0,
                    message=litellm.Message(
                        content=None,
                        role="assistant",
                        tool_calls=[
                            litellm.ChatCompletionMessageToolCall(
                                id="call_1",
                                function=LiteLLMFunction(
                                    name="get_weather",
                                    arguments='{"location": "New York"}',
                                ),
                                type="function",
                            ),
                            litellm.ChatCompletionMessageToolCall(
                                id="call_2",
                                function=LiteLLMFunction(
                                    name="get_time", arguments='{"timezone": "EST"}'
                                ),
                                type="function",
                            ),
                        ],
                    ),
                )
            ],
            model="gpt-3.5-turbo",
            usage=litellm.Usage(
                prompt_tokens=50, completion_tokens=30, total_tokens=80
            ),
        )
        mock_completion.return_value = mock_response

        # Define input messages
        messages = [HumanMessage(content="What's the weather and time in New York?")]

        # Define available tools
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get the current time for a timezone",
                    "parameters": {
                        "type": "object",
                        "properties": {"timezone": {"type": "string"}},
                        "required": ["timezone"],
                    },
                },
            },
        ]

        # Call the _invoke_implementation method
        result = default_multi_llm.invoke(messages, tools)

        # Assert that the result is an AIMessage
        assert isinstance(result, AIMessage)

        # Assert that the content is None (as per the mock response)
        assert result.content == ""

        # Assert that there are two tool calls
        assert len(result.tool_calls) == 2

        # Assert the details of the first tool call
        assert result.tool_calls[0]["id"] == "call_1"
        assert result.tool_calls[0]["name"] == "get_weather"
        assert result.tool_calls[0]["args"] == {"location": "New York"}

        # Assert the details of the second tool call
        assert result.tool_calls[1]["id"] == "call_2"
        assert result.tool_calls[1]["name"] == "get_time"
        assert result.tool_calls[1]["args"] == {"timezone": "EST"}

        # Verify that litellm.completion was called with the correct arguments
        mock_completion.assert_called_once_with(
            model="openai/gpt-3.5-turbo",
            api_key="test_key",
            base_url=None,
            api_version=None,
            custom_llm_provider=None,
            messages=[
                {"role": "user", "content": "What's the weather and time in New York?"}
            ],
            tools=tools,
            tool_choice=None,
            stream=False,
            temperature=0.0,  # Default value from GEN_AI_TEMPERATURE
            timeout=30,
            parallel_tool_calls=False,
        )


def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> None:
    # Mock the litellm.completion function
    with patch("danswer.llm.chat_llm.litellm.completion") as mock_completion:
        # Create a mock response with multiple tool calls using litellm objects
        mock_response = [
            litellm.ModelResponse(
                id="chatcmpl-123",
                choices=[
                    litellm.Choices(
                        delta=_create_delta(
                            role="assistant",
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    id="call_1",
                                    function=LiteLLMFunction(
                                        name="get_weather", arguments='{"location": '
                                    ),
                                    type="function",
                                    index=0,
                                )
                            ],
                        ),
                        finish_reason=None,
                        index=0,
                    )
                ],
                model="gpt-3.5-turbo",
            ),
            litellm.ModelResponse(
                id="chatcmpl-123",
                choices=[
                    litellm.Choices(
                        delta=_create_delta(
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    id="",
                                    function=LiteLLMFunction(arguments='"New York"}'),
                                    type="function",
                                    index=0,
                                )
                            ]
                        ),
                        finish_reason=None,
                        index=0,
                    )
                ],
                model="gpt-3.5-turbo",
            ),
            litellm.ModelResponse(
                id="chatcmpl-123",
                choices=[
                    litellm.Choices(
                        delta=_create_delta(
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    id="call_2",
                                    function=LiteLLMFunction(
                                        name="get_time", arguments='{"timezone": "EST"}'
                                    ),
                                    type="function",
                                    index=1,
                                )
                            ]
                        ),
                        finish_reason="tool_calls",
                        index=0,
                    )
                ],
                model="gpt-3.5-turbo",
            ),
        ]
        mock_completion.return_value = mock_response

        # Define input messages and tools (same as in the non-streaming test)
        messages = [HumanMessage(content="What's the weather and time in New York?")]

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get the current time for a timezone",
                    "parameters": {
                        "type": "object",
                        "properties": {"timezone": {"type": "string"}},
                        "required": ["timezone"],
                    },
                },
            },
        ]

        # Call the stream method
        stream_result = list(default_multi_llm.stream(messages, tools))

        # Assert that we received the correct number of chunks
        assert len(stream_result) == 3

        # Combine all chunks into a single AIMessage
        combined_result: AIMessage = AIMessageChunk(content="")
        for chunk in stream_result:
            combined_result += chunk  # type: ignore

        # Assert that the combined result matches our expectations
        assert isinstance(combined_result, AIMessage)
        assert combined_result.content == ""
        assert len(combined_result.tool_calls) == 2
        assert combined_result.tool_calls[0]["id"] == "call_1"
        assert combined_result.tool_calls[0]["name"] == "get_weather"
        assert combined_result.tool_calls[0]["args"] == {"location": "New York"}
        assert combined_result.tool_calls[1]["id"] == "call_2"
        assert combined_result.tool_calls[1]["name"] == "get_time"
        assert combined_result.tool_calls[1]["args"] == {"timezone": "EST"}

        # Verify that litellm.completion was called with the correct arguments
        mock_completion.assert_called_once_with(
            model="openai/gpt-3.5-turbo",
            api_key="test_key",
            base_url=None,
            api_version=None,
            custom_llm_provider=None,
            messages=[
                {"role": "user", "content": "What's the weather and time in New York?"}
            ],
            tools=tools,
            tool_choice=None,
            stream=True,
            temperature=0.0,  # Default value from GEN_AI_TEMPERATURE
            timeout=30,
            parallel_tool_calls=False,
        )
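To run the new tests locally, something like `pytest tests/unit/danswer/llm/test_chat_llm.py` from the backend directory should work, assuming the danswer package and its test dependencies are installed (the exact invocation depends on the repo's pytest setup).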