From 1e85eb0aacf2ff5f3e50aa38ebd95b486bab9968 Mon Sep 17 00:00:00 2001 From: ZeroIce <39822906+vicksiyi@users.noreply.github.com> Date: Fri, 3 Jul 2026 12:00:08 +0800 Subject: [PATCH] Fix Kiro ACP usage accounting (#4867) Co-authored-by: multica-agent --- server/pkg/agent/hermes.go | 97 ++++++++++++++++++++++++--------- server/pkg/agent/hermes_test.go | 24 ++++++++ server/pkg/agent/kimi.go | 2 +- server/pkg/agent/kiro.go | 13 ++++- server/pkg/agent/kiro_test.go | 47 +++++++++++++++- server/pkg/agent/qoder.go | 2 +- server/pkg/agent/traecli.go | 2 +- 7 files changed, 152 insertions(+), 35 deletions(-) diff --git a/server/pkg/agent/hermes.go b/server/pkg/agent/hermes.go index 841dbda72..7cf5c193a 100644 --- a/server/pkg/agent/hermes.go +++ b/server/pkg/agent/hermes.go @@ -424,7 +424,7 @@ func (b *hermesBackend) Execute(ctx context.Context, prompt string, opts ExecOpt c.usageMu.Unlock() var usageMap map[string]TokenUsage - if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 { + if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 { model := effectiveModel if model == "" { model = "unknown" @@ -732,14 +732,8 @@ func (c *hermesClient) handleResponse(raw map[string]json.RawMessage) { func (c *hermesClient) extractPromptResult(data json.RawMessage) { var resp struct { - StopReason string `json:"stopReason"` - Usage *struct { - InputTokens int64 `json:"inputTokens"` - OutputTokens int64 `json:"outputTokens"` - TotalTokens int64 `json:"totalTokens"` - ThoughtTokens int64 `json:"thoughtTokens"` - CachedReadTokens int64 `json:"cachedReadTokens"` - } `json:"usage"` + StopReason string `json:"stopReason"` + Usage json.RawMessage `json:"usage"` } if err := json.Unmarshal(data, &resp); err != nil { return @@ -748,12 +742,8 @@ func (c *hermesClient) extractPromptResult(data json.RawMessage) { pr := hermesPromptResult{ stopReason: resp.StopReason, } - if resp.Usage != nil { - pr.usage = TokenUsage{ - InputTokens: resp.Usage.InputTokens, - OutputTokens: resp.Usage.OutputTokens, - CacheReadTokens: resp.Usage.CachedReadTokens, - } + if len(resp.Usage) > 0 && string(resp.Usage) != "null" { + pr.usage = parseACPTokenUsage(resp.Usage) } if c.onPromptDone != nil { @@ -1190,31 +1180,84 @@ func extractACPToolCallText(blocks []json.RawMessage) string { func (c *hermesClient) handleUsageUpdate(data json.RawMessage) { var msg struct { - Usage struct { - InputTokens int64 `json:"inputTokens"` - OutputTokens int64 `json:"outputTokens"` - TotalTokens int64 `json:"totalTokens"` - CachedReadTokens int64 `json:"cachedReadTokens"` - } `json:"usage"` + Usage json.RawMessage `json:"usage"` } if err := json.Unmarshal(data, &msg); err != nil { return } + usage := parseACPTokenUsage(msg.Usage) c.usageMu.Lock() // Usage updates from ACP are cumulative snapshots, so take the latest. - if msg.Usage.InputTokens > c.usage.InputTokens { - c.usage.InputTokens = msg.Usage.InputTokens + if usage.InputTokens > c.usage.InputTokens { + c.usage.InputTokens = usage.InputTokens } - if msg.Usage.OutputTokens > c.usage.OutputTokens { - c.usage.OutputTokens = msg.Usage.OutputTokens + if usage.OutputTokens > c.usage.OutputTokens { + c.usage.OutputTokens = usage.OutputTokens } - if msg.Usage.CachedReadTokens > c.usage.CacheReadTokens { - c.usage.CacheReadTokens = msg.Usage.CachedReadTokens + if usage.CacheReadTokens > c.usage.CacheReadTokens { + c.usage.CacheReadTokens = usage.CacheReadTokens + } + if usage.CacheWriteTokens > c.usage.CacheWriteTokens { + c.usage.CacheWriteTokens = usage.CacheWriteTokens } c.usageMu.Unlock() } +func parseACPTokenUsage(data json.RawMessage) TokenUsage { + if len(data) == 0 || string(data) == "null" { + return TokenUsage{} + } + var fields map[string]json.RawMessage + if err := json.Unmarshal(data, &fields); err != nil { + return TokenUsage{} + } + return TokenUsage{ + InputTokens: acpUsageInt64(fields, "inputTokens", "input_tokens"), + OutputTokens: acpUsageInt64(fields, "outputTokens", "output_tokens"), + CacheReadTokens: acpUsageInt64(fields, + "cachedReadTokens", + "cacheReadTokens", + "cached_input_tokens", + "cache_read_tokens", + "cache_read_input_tokens", + ), + CacheWriteTokens: acpUsageInt64(fields, + "cachedWriteTokens", + "cacheWriteTokens", + "cache_write_tokens", + "cache_creation_input_tokens", + ), + } +} + +func acpUsageInt64(fields map[string]json.RawMessage, names ...string) int64 { + for _, name := range names { + raw, ok := fields[name] + if !ok { + continue + } + var n json.Number + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.UseNumber() + if err := dec.Decode(&n); err == nil { + if v, err := n.Int64(); err == nil { + return v + } + if f, err := n.Float64(); err == nil { + return int64(f) + } + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + if v, err := strconv.ParseInt(strings.TrimSpace(s), 10, 64); err == nil { + return v + } + } + } + return 0 +} + // ── Helpers ── // extractACPSessionID pulls `sessionId` out of a session/new or diff --git a/server/pkg/agent/hermes_test.go b/server/pkg/agent/hermes_test.go index bf8a0d26a..0605e3eae 100644 --- a/server/pkg/agent/hermes_test.go +++ b/server/pkg/agent/hermes_test.go @@ -859,6 +859,30 @@ func TestHermesClientHandleSessionNotificationTurnEnd(t *testing.T) { } } +func TestParseACPTokenUsageAliases(t *testing.T) { + t.Parallel() + + usage := parseACPTokenUsage(json.RawMessage(`{ + "input_tokens": 11, + "output_tokens": "7", + "cacheReadTokens": 5, + "cache_creation_input_tokens": 3 + }`)) + + if usage.InputTokens != 11 { + t.Errorf("InputTokens: got %d, want 11", usage.InputTokens) + } + if usage.OutputTokens != 7 { + t.Errorf("OutputTokens: got %d, want 7", usage.OutputTokens) + } + if usage.CacheReadTokens != 5 { + t.Errorf("CacheReadTokens: got %d, want 5", usage.CacheReadTokens) + } + if usage.CacheWriteTokens != 3 { + t.Errorf("CacheWriteTokens: got %d, want 3", usage.CacheWriteTokens) + } +} + func TestHermesClientHandleToolCallComplete(t *testing.T) { t.Parallel() diff --git a/server/pkg/agent/kimi.go b/server/pkg/agent/kimi.go index db266e4f5..3ccc476d3 100644 --- a/server/pkg/agent/kimi.go +++ b/server/pkg/agent/kimi.go @@ -375,7 +375,7 @@ func (b *kimiBackend) Execute(ctx context.Context, prompt string, opts ExecOptio c.usageMu.Unlock() var usageMap map[string]TokenUsage - if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 { + if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 { model := opts.Model if model == "" { model = "unknown" diff --git a/server/pkg/agent/kiro.go b/server/pkg/agent/kiro.go index ae18bb6a6..00a644e2f 100644 --- a/server/pkg/agent/kiro.go +++ b/server/pkg/agent/kiro.go @@ -189,6 +189,7 @@ func (b *kiroBackend) Execute(ctx context.Context, prompt string, opts ExecOptio finalStatus := "completed" var finalError string var sessionID string + effectiveModel := strings.TrimSpace(opts.Model) initResult, err := c.request(runCtx, "initialize", map[string]any{ "protocolVersion": 1, @@ -245,6 +246,9 @@ func (b *kiroBackend) Execute(ctx context.Context, prompt string, opts ExecOptio "actual", sessionID, ) } + if effectiveModel == "" { + effectiveModel = extractACPCurrentModelID(result) + } } else { result, err := c.request(runCtx, "session/new", map[string]any{ "cwd": cwd, @@ -263,6 +267,9 @@ func (b *kiroBackend) Execute(ctx context.Context, prompt string, opts ExecOptio resCh <- Result{Status: finalStatus, Error: finalError, DurationMs: time.Since(startTime).Milliseconds()} return } + if effectiveModel == "" { + effectiveModel = extractACPCurrentModelID(result) + } } c.sessionID = sessionID @@ -354,6 +361,8 @@ func (b *kiroBackend) Execute(ctx context.Context, prompt string, opts ExecOptio c.usageMu.Lock() c.usage.InputTokens += pr.usage.InputTokens c.usage.OutputTokens += pr.usage.OutputTokens + c.usage.CacheReadTokens += pr.usage.CacheReadTokens + c.usage.CacheWriteTokens += pr.usage.CacheWriteTokens c.usageMu.Unlock() default: } @@ -387,8 +396,8 @@ func (b *kiroBackend) Execute(ctx context.Context, prompt string, opts ExecOptio c.usageMu.Unlock() var usageMap map[string]TokenUsage - if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 { - model := opts.Model + if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 { + model := effectiveModel if model == "" { model = "unknown" } diff --git a/server/pkg/agent/kiro_test.go b/server/pkg/agent/kiro_test.go index 8a4668dd5..60f341e6d 100644 --- a/server/pkg/agent/kiro_test.go +++ b/server/pkg/agent/kiro_test.go @@ -106,7 +106,7 @@ while IFS= read -r line; do esac printf '{"jsonrpc":"2.0","method":"session/notification","params":{"sessionId":"ses_loaded","update":{"type":"ToolCallUpdate","toolCallId":"tc-current","status":"completed","name":"Shell","parameters":{"command":"echo current"},"output":"current tool output\\n"}}}\n' printf '{"jsonrpc":"2.0","method":"session/notification","params":{"sessionId":"ses_loaded","update":{"type":"AgentMessageChunk","content":{"type":"text","text":"loaded"}}}}\n' - printf '{"jsonrpc":"2.0","id":%s,"result":{"stopReason":"end_turn","usage":{"inputTokens":2,"outputTokens":1}}}\n' "$id" + printf '{"jsonrpc":"2.0","id":%s,"result":{"stopReason":"end_turn","usage":{"inputTokens":2,"outputTokens":1,"cacheReadTokens":7,"cacheWriteTokens":3}}}\n' "$id" exit 0 ;; esac @@ -162,6 +162,47 @@ func TestKiroBackendSetModelFailureFailsTask(t *testing.T) { } } +func TestKiroBackendAttributesUsageToCurrentModel(t *testing.T) { + t.Parallel() + + fakePath := filepath.Join(t.TempDir(), "kiro-cli") + writeTestExecutable(t, fakePath, []byte(fakeKiroACPScript())) + + backend, err := New("kiro", Config{ExecutablePath: fakePath, Logger: slog.Default()}) + if err != nil { + t.Fatalf("new kiro backend: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + session, err := backend.Execute(ctx, "prompt-ignored", ExecOptions{ + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + go func() { + for range session.Messages { + } + }() + + result := <-session.Result + if result.Status != "completed" { + t.Fatalf("expected completed result, got status=%q error=%q", result.Status, result.Error) + } + if _, ok := result.Usage["unknown"]; ok { + t.Fatalf("usage should use Kiro current model, got unknown entry: %+v", result.Usage) + } + usage, ok := result.Usage["auto"] + if !ok { + t.Fatalf("expected usage under current model auto, got %+v", result.Usage) + } + if usage.InputTokens != 2 || usage.OutputTokens != 1 || usage.CacheReadTokens != 7 || usage.CacheWriteTokens != 3 { + t.Fatalf("usage = %+v, want input=2 output=1 cache_read=7 cache_write=3", usage) + } +} + func fakeKiroACPGoalCompleteCloseErrorScript(goalStatus string) string { return `#!/bin/sh while IFS= read -r line; do @@ -551,8 +592,8 @@ func TestKiroBackendUsesSessionLoadForResume(t *testing.T) { if result.Output != "loaded" { t.Fatalf("output = %q, want loaded", result.Output) } - if usage := result.Usage["unknown"]; usage.InputTokens != 2 || usage.OutputTokens != 1 || usage.CacheReadTokens != 0 { - t.Fatalf("usage = %+v, want input=2 output=1 cache_read=0", usage) + if usage := result.Usage["unknown"]; usage.InputTokens != 2 || usage.OutputTokens != 1 || usage.CacheReadTokens != 7 || usage.CacheWriteTokens != 3 { + t.Fatalf("usage = %+v, want input=2 output=1 cache_read=7 cache_write=3", usage) } if len(messages) != 3 { t.Fatalf("messages = %+v, want current tool use, tool result, and text only", messages) diff --git a/server/pkg/agent/qoder.go b/server/pkg/agent/qoder.go index 6a0a4b5f9..b04d49cfe 100644 --- a/server/pkg/agent/qoder.go +++ b/server/pkg/agent/qoder.go @@ -414,7 +414,7 @@ func (b *qoderBackend) Execute(ctx context.Context, prompt string, opts ExecOpti c.usageMu.Unlock() var usageMap map[string]TokenUsage - if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 { + if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 { model := effectiveModel if model == "" { model = "unknown" diff --git a/server/pkg/agent/traecli.go b/server/pkg/agent/traecli.go index 568d415b9..8e39f5eb8 100644 --- a/server/pkg/agent/traecli.go +++ b/server/pkg/agent/traecli.go @@ -417,7 +417,7 @@ func (b *traecliBackend) Execute(ctx context.Context, prompt string, opts ExecOp c.usageMu.Unlock() var usageMap map[string]TokenUsage - if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 { + if u.InputTokens > 0 || u.OutputTokens > 0 || u.CacheReadTokens > 0 || u.CacheWriteTokens > 0 { model := effectiveModel if model == "" { model = "unknown"