mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 08:47:41 +01:00
Add Tool Call ID (#12956)
* routes/types: add tool call id --------- Co-authored-by: ParthSareen <parth.sareen@ollama.com>
This commit is contained in:
10
api/types.go
10
api/types.go
@@ -181,10 +181,11 @@ type Message struct {
|
|||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
// Thinking contains the text that was inside thinking tags in the
|
// Thinking contains the text that was inside thinking tags in the
|
||||||
// original model output when ChatRequest.Think is enabled.
|
// original model output when ChatRequest.Think is enabled.
|
||||||
Thinking string `json:"thinking,omitempty"`
|
Thinking string `json:"thinking,omitempty"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolName string `json:"tool_name,omitempty"`
|
ToolName string `json:"tool_name,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
@@ -200,6 +201,7 @@ func (m *Message) UnmarshalJSON(b []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
Function ToolCallFunction `json:"function"`
|
Function ToolCallFunction `json:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -229,20 +228,11 @@ func ToUsage(r api.ChatResponse) Usage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toolCallId() string {
|
|
||||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
b := make([]byte, 8)
|
|
||||||
for i := range b {
|
|
||||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
|
||||||
}
|
|
||||||
return "call_" + strings.ToLower(string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
|
// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
|
||||||
func ToToolCalls(tc []api.ToolCall) []ToolCall {
|
func ToToolCalls(tc []api.ToolCall) []ToolCall {
|
||||||
toolCalls := make([]ToolCall, len(tc))
|
toolCalls := make([]ToolCall, len(tc))
|
||||||
for i, tc := range tc {
|
for i, tc := range tc {
|
||||||
toolCalls[i].ID = toolCallId()
|
toolCalls[i].ID = tc.ID
|
||||||
toolCalls[i].Type = "function"
|
toolCalls[i].Type = "function"
|
||||||
toolCalls[i].Function.Name = tc.Function.Name
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
toolCalls[i].Index = tc.Function.Index
|
toolCalls[i].Index = tc.Function.Index
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -148,3 +150,71 @@ func TestNewError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||||
|
original := []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc123",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Index: 2,
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Seattle",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "call_def456",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Index: 7,
|
||||||
|
Name: "get_time",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"timezone": "UTC",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := make([]api.ToolCall, len(original))
|
||||||
|
copy(toolCalls, original)
|
||||||
|
got := ToToolCalls(toolCalls)
|
||||||
|
|
||||||
|
if len(got) != len(original) {
|
||||||
|
t.Fatalf("expected %d tool calls, got %d", len(original), len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc123",
|
||||||
|
Type: "function",
|
||||||
|
Index: 2,
|
||||||
|
Function: struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: `{"location":"Seattle"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "call_def456",
|
||||||
|
Type: "function",
|
||||||
|
Index: 7,
|
||||||
|
Function: struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}{
|
||||||
|
Name: "get_time",
|
||||||
|
Arguments: `{"timezone":"UTC"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(expected, got); diff != "" {
|
||||||
|
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(original, toolCalls); diff != "" {
|
||||||
|
t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -1812,6 +1813,15 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toolCallId() string {
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b := make([]byte, 8)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||||
|
}
|
||||||
|
return "call_" + strings.ToLower(string(b))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) ChatHandler(c *gin.Context) {
|
func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
@@ -2130,6 +2140,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
|
|
||||||
res.Message.Content = content
|
res.Message.Content = content
|
||||||
res.Message.Thinking = thinking
|
res.Message.Thinking = thinking
|
||||||
|
for i := range toolCalls {
|
||||||
|
toolCalls[i].ID = toolCallId()
|
||||||
|
}
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
|
|
||||||
tb.WriteString(thinking)
|
tb.WriteString(thinking)
|
||||||
@@ -2174,6 +2187,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
if len(content) > 0 {
|
if len(content) > 0 {
|
||||||
res.Message.Content = content
|
res.Message.Content = content
|
||||||
} else if len(toolCalls) > 0 {
|
} else if len(toolCalls) > 0 {
|
||||||
|
for i := range toolCalls {
|
||||||
|
toolCalls[i].ID = toolCallId()
|
||||||
|
}
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
} else if res.Message.Thinking != "" {
|
} else if res.Message.Thinking != "" {
|
||||||
|
|||||||
@@ -554,6 +554,14 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
t.Error("expected tool calls, got nil")
|
t.Error("expected tool calls, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gotToolCall := resp.Message.ToolCalls[0]
|
||||||
|
if gotToolCall.ID == "" {
|
||||||
|
t.Error("expected tool call ID to be populated")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(gotToolCall.ID, "call_") {
|
||||||
|
t.Errorf("expected tool call ID to have call_ prefix, got %q", gotToolCall.ID)
|
||||||
|
}
|
||||||
|
|
||||||
expectedToolCall := api.ToolCall{
|
expectedToolCall := api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
@@ -564,7 +572,8 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
|
expectedToolCall.ID = gotToolCall.ID
|
||||||
|
if diff := cmp.Diff(gotToolCall, expectedToolCall); diff != "" {
|
||||||
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -669,6 +678,17 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(resp.Message.ToolCalls) > 0 {
|
||||||
|
for _, call := range resp.Message.ToolCalls {
|
||||||
|
if call.ID == "" {
|
||||||
|
t.Fatal("expected streaming tool call to have an ID")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(call.ID, "call_") {
|
||||||
|
t.Fatalf("expected streaming tool call ID to have call_ prefix, got %q", call.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
if len(resp.Message.ToolCalls) != 1 {
|
if len(resp.Message.ToolCalls) != 1 {
|
||||||
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
||||||
@@ -687,6 +707,14 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if finalToolCall.ID == "" {
|
||||||
|
t.Fatal("expected final tool call to have an ID")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(finalToolCall.ID, "call_") {
|
||||||
|
t.Fatalf("expected final tool call ID to have call_ prefix, got %q", finalToolCall.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedToolCall.ID = finalToolCall.ID
|
||||||
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
||||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user