diff --git a/server/pkg/agent/claude.go b/server/pkg/agent/claude.go index ae040201e..2a73ef254 100644 --- a/server/pkg/agent/claude.go +++ b/server/pkg/agent/claude.go @@ -171,6 +171,9 @@ func (b *claudeBackend) Execute(ctx context.Context, prompt string, opts ExecOpt output.Reset() output.WriteString(msg.ResultText) } + if resultUsage := claudeResultUsage(msg, opts.Model); len(resultUsage) > 0 { + usage = resultUsage + } if msg.IsError { finalStatus = "failed" finalError = msg.ResultText @@ -341,12 +344,15 @@ type claudeSDKMessage struct { Message json.RawMessage `json:"message,omitempty"` Subtype string `json:"subtype,omitempty"` SessionID string `json:"session_id,omitempty"` + Model string `json:"model,omitempty"` // result fields - ResultText string `json:"result,omitempty"` - IsError bool `json:"is_error,omitempty"` - DurationMs float64 `json:"duration_ms,omitempty"` - NumTurns int `json:"num_turns,omitempty"` + ResultText string `json:"result,omitempty"` + IsError bool `json:"is_error,omitempty"` + DurationMs float64 `json:"duration_ms,omitempty"` + NumTurns int `json:"num_turns,omitempty"` + Usage *claudeUsage `json:"usage,omitempty"` + ModelUsage map[string]claudeResultModelUsage `json:"modelUsage,omitempty"` // log fields Log *claudeLogEntry `json:"log,omitempty"` @@ -375,6 +381,58 @@ type claudeUsage struct { CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"` } +type claudeResultModelUsage struct { + InputTokens int64 `json:"inputTokens"` + OutputTokens int64 `json:"outputTokens"` + CacheReadInputTokens int64 `json:"cacheReadInputTokens"` + CacheCreationInputTokens int64 `json:"cacheCreationInputTokens"` +} + +func claudeResultUsage(msg claudeSDKMessage, fallbackModel string) map[string]TokenUsage { + if len(msg.ModelUsage) > 0 { + usage := make(map[string]TokenUsage, len(msg.ModelUsage)) + for model, u := range msg.ModelUsage { + if model == "" || !claudeUsageHasTokens(u.InputTokens, u.OutputTokens, u.CacheReadInputTokens, u.CacheCreationInputTokens) { + continue + } + usage[model] = TokenUsage{ + InputTokens: u.InputTokens, + OutputTokens: u.OutputTokens, + CacheReadTokens: u.CacheReadInputTokens, + CacheWriteTokens: u.CacheCreationInputTokens, + } + } + if len(usage) > 0 { + return usage + } + } + + model := msg.Model + if model == "" { + model = fallbackModel + } + if msg.Usage == nil || model == "" || !claudeUsageHasTokens( + msg.Usage.InputTokens, + msg.Usage.OutputTokens, + msg.Usage.CacheReadInputTokens, + msg.Usage.CacheCreationInputTokens, + ) { + return nil + } + return map[string]TokenUsage{ + model: { + InputTokens: msg.Usage.InputTokens, + OutputTokens: msg.Usage.OutputTokens, + CacheReadTokens: msg.Usage.CacheReadInputTokens, + CacheWriteTokens: msg.Usage.CacheCreationInputTokens, + }, + } +} + +func claudeUsageHasTokens(input, output, cacheRead, cacheWrite int64) bool { + return input > 0 || output > 0 || cacheRead > 0 || cacheWrite > 0 +} + type claudeContentBlock struct { Type string `json:"type"` Text string `json:"text,omitempty"` diff --git a/server/pkg/agent/claude_test.go b/server/pkg/agent/claude_test.go index e0076cd1d..f96dee036 100644 --- a/server/pkg/agent/claude_test.go +++ b/server/pkg/agent/claude_test.go @@ -595,6 +595,52 @@ func TestClaudeExecuteSurfacesStderrWhenChildExitsEarly(t *testing.T) { } } +func TestClaudeExecuteRecordsResultModelUsage(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("shell-script fixture is POSIX-only") + } + + fakePath := filepath.Join(t.TempDir(), "claude") + script := "#!/bin/sh\n" + + "cat >/dev/null\n" + + "printf '%s\\n' '{\"type\":\"system\",\"session_id\":\"sess-result-usage\"}'\n" + + "printf '%s\\n' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"session_id\":\"sess-result-usage\",\"result\":\"done\",\"modelUsage\":{\"zhipu/coding-plan\":{\"inputTokens\":123,\"outputTokens\":45,\"cacheReadInputTokens\":7,\"cacheCreationInputTokens\":11,\"costUSD\":0.01}}}'\n" + writeTestExecutable(t, fakePath, []byte(script)) + + backend, err := New("claude", Config{ExecutablePath: fakePath, Logger: slog.Default()}) + if err != nil { + t.Fatalf("new claude backend: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + session, err := backend.Execute(ctx, "prompt-ignored", ExecOptions{Timeout: 5 * time.Second}) + if err != nil { + t.Fatalf("execute: %v", err) + } + go func() { + for range session.Messages { + } + }() + + select { + case result, ok := <-session.Result: + if !ok { + t.Fatal("result channel closed without a value") + } + usage, ok := result.Usage["zhipu/coding-plan"] + if !ok { + t.Fatalf("expected usage for zhipu/coding-plan, got %#v", result.Usage) + } + if usage.InputTokens != 123 || usage.OutputTokens != 45 || usage.CacheReadTokens != 7 || usage.CacheWriteTokens != 11 { + t.Fatalf("unexpected usage: %+v", usage) + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for result") + } +} + func mustMarshal(t *testing.T, v any) json.RawMessage { t.Helper() data, err := json.Marshal(v)