mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 19:58:29 +01:00
tools: parse tool calls that don't conform to ("name": name, "arguments": args} (#12738)
This commit is contained in:
@@ -125,7 +125,7 @@ func (p *Parser) parseToolCall() *api.ToolCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var args map[string]any
|
var args map[string]any
|
||||||
if found, i := findArguments(p.buffer); found == nil {
|
if found, i := findArguments(tool, p.buffer); found == nil {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
args = found
|
args = found
|
||||||
@@ -219,7 +219,7 @@ func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
|
|||||||
// objects for functions that have all-optional parameters
|
// objects for functions that have all-optional parameters
|
||||||
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
|
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
|
||||||
// `{"name": "get_conditions"}` will not currently work
|
// `{"name": "get_conditions"}` will not currently work
|
||||||
func findArguments(buffer []byte) (map[string]any, int) {
|
func findArguments(tool *api.Tool, buffer []byte) (map[string]any, int) {
|
||||||
if len(buffer) == 0 {
|
if len(buffer) == 0 {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
@@ -269,27 +269,30 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
|||||||
|
|
||||||
var findObject func(obj map[string]any) (map[string]any, bool)
|
var findObject func(obj map[string]any) (map[string]any, bool)
|
||||||
findObject = func(obj map[string]any) (map[string]any, bool) {
|
findObject = func(obj map[string]any) (map[string]any, bool) {
|
||||||
|
findMap := func(name string, obj map[string]any) (map[string]any, bool) {
|
||||||
|
if args, ok := obj[name].(map[string]any); ok {
|
||||||
|
return args, true
|
||||||
|
}
|
||||||
|
if argsStr, ok := obj[name].(string); ok {
|
||||||
|
var argsData map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
||||||
|
return argsData, ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
if _, hasName := obj["name"]; hasName {
|
if _, hasName := obj["name"]; hasName {
|
||||||
if args, ok := obj["arguments"].(map[string]any); ok {
|
if args, ok := findMap("arguments", obj); ok {
|
||||||
return args, true
|
return args, true
|
||||||
}
|
}
|
||||||
if argsStr, ok := obj["arguments"].(string); ok {
|
if args, ok := findMap("parameters", obj); ok {
|
||||||
var argsData map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
|
||||||
return argsData, ok
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if args, ok := obj["parameters"].(map[string]any); ok {
|
|
||||||
return args, true
|
return args, true
|
||||||
}
|
}
|
||||||
if argsStr, ok := obj["parameters"].(string); ok {
|
|
||||||
var argsData map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
|
||||||
return argsData, ok
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
if args, ok := findMap(tool.Function.Name, obj); ok {
|
||||||
|
return args, true
|
||||||
|
}
|
||||||
|
|
||||||
for _, v := range obj {
|
for _, v := range obj {
|
||||||
switch child := v.(type) {
|
switch child := v.(type) {
|
||||||
|
|||||||
@@ -1033,6 +1033,7 @@ func TestFindArguments(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
buffer []byte
|
buffer []byte
|
||||||
want map[string]any
|
want map[string]any
|
||||||
|
tool string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "empty string",
|
name: "empty string",
|
||||||
@@ -1290,11 +1291,29 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "simple tool call",
|
||||||
|
tool: "get_temperature",
|
||||||
|
buffer: []byte(`{"get_temperature": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stringified simple tool call",
|
||||||
|
tool: "get_temperature",
|
||||||
|
buffer: []byte(`{"get_temperature": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, _ := findArguments(tt.buffer)
|
got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
||||||
|
|||||||
Reference in New Issue
Block a user