fix(claude): record result model usage (#2899)

This commit is contained in:
YOMXXX
2026-05-21 13:00:12 +08:00
committed by GitHub
parent 2f1f90c11a
commit ed2957ddf8
2 changed files with 108 additions and 4 deletions

View File

@@ -171,6 +171,9 @@ func (b *claudeBackend) Execute(ctx context.Context, prompt string, opts ExecOpt
output.Reset()
output.WriteString(msg.ResultText)
}
if resultUsage := claudeResultUsage(msg, opts.Model); len(resultUsage) > 0 {
usage = resultUsage
}
if msg.IsError {
finalStatus = "failed"
finalError = msg.ResultText
@@ -341,12 +344,15 @@ type claudeSDKMessage struct {
Message json.RawMessage `json:"message,omitempty"`
Subtype string `json:"subtype,omitempty"`
SessionID string `json:"session_id,omitempty"`
Model string `json:"model,omitempty"`
// result fields
ResultText string `json:"result,omitempty"`
IsError bool `json:"is_error,omitempty"`
DurationMs float64 `json:"duration_ms,omitempty"`
NumTurns int `json:"num_turns,omitempty"`
ResultText string `json:"result,omitempty"`
IsError bool `json:"is_error,omitempty"`
DurationMs float64 `json:"duration_ms,omitempty"`
NumTurns int `json:"num_turns,omitempty"`
Usage *claudeUsage `json:"usage,omitempty"`
ModelUsage map[string]claudeResultModelUsage `json:"modelUsage,omitempty"`
// log fields
Log *claudeLogEntry `json:"log,omitempty"`
@@ -375,6 +381,58 @@ type claudeUsage struct {
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
}
type claudeResultModelUsage struct {
InputTokens int64 `json:"inputTokens"`
OutputTokens int64 `json:"outputTokens"`
CacheReadInputTokens int64 `json:"cacheReadInputTokens"`
CacheCreationInputTokens int64 `json:"cacheCreationInputTokens"`
}
func claudeResultUsage(msg claudeSDKMessage, fallbackModel string) map[string]TokenUsage {
if len(msg.ModelUsage) > 0 {
usage := make(map[string]TokenUsage, len(msg.ModelUsage))
for model, u := range msg.ModelUsage {
if model == "" || !claudeUsageHasTokens(u.InputTokens, u.OutputTokens, u.CacheReadInputTokens, u.CacheCreationInputTokens) {
continue
}
usage[model] = TokenUsage{
InputTokens: u.InputTokens,
OutputTokens: u.OutputTokens,
CacheReadTokens: u.CacheReadInputTokens,
CacheWriteTokens: u.CacheCreationInputTokens,
}
}
if len(usage) > 0 {
return usage
}
}
model := msg.Model
if model == "" {
model = fallbackModel
}
if msg.Usage == nil || model == "" || !claudeUsageHasTokens(
msg.Usage.InputTokens,
msg.Usage.OutputTokens,
msg.Usage.CacheReadInputTokens,
msg.Usage.CacheCreationInputTokens,
) {
return nil
}
return map[string]TokenUsage{
model: {
InputTokens: msg.Usage.InputTokens,
OutputTokens: msg.Usage.OutputTokens,
CacheReadTokens: msg.Usage.CacheReadInputTokens,
CacheWriteTokens: msg.Usage.CacheCreationInputTokens,
},
}
}
func claudeUsageHasTokens(input, output, cacheRead, cacheWrite int64) bool {
return input > 0 || output > 0 || cacheRead > 0 || cacheWrite > 0
}
type claudeContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`

View File

@@ -595,6 +595,52 @@ func TestClaudeExecuteSurfacesStderrWhenChildExitsEarly(t *testing.T) {
}
}
func TestClaudeExecuteRecordsResultModelUsage(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("shell-script fixture is POSIX-only")
}
fakePath := filepath.Join(t.TempDir(), "claude")
script := "#!/bin/sh\n" +
"cat >/dev/null\n" +
"printf '%s\\n' '{\"type\":\"system\",\"session_id\":\"sess-result-usage\"}'\n" +
"printf '%s\\n' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"session_id\":\"sess-result-usage\",\"result\":\"done\",\"modelUsage\":{\"zhipu/coding-plan\":{\"inputTokens\":123,\"outputTokens\":45,\"cacheReadInputTokens\":7,\"cacheCreationInputTokens\":11,\"costUSD\":0.01}}}'\n"
writeTestExecutable(t, fakePath, []byte(script))
backend, err := New("claude", Config{ExecutablePath: fakePath, Logger: slog.Default()})
if err != nil {
t.Fatalf("new claude backend: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
session, err := backend.Execute(ctx, "prompt-ignored", ExecOptions{Timeout: 5 * time.Second})
if err != nil {
t.Fatalf("execute: %v", err)
}
go func() {
for range session.Messages {
}
}()
select {
case result, ok := <-session.Result:
if !ok {
t.Fatal("result channel closed without a value")
}
usage, ok := result.Usage["zhipu/coding-plan"]
if !ok {
t.Fatalf("expected usage for zhipu/coding-plan, got %#v", result.Usage)
}
if usage.InputTokens != 123 || usage.OutputTokens != 45 || usage.CacheReadTokens != 7 || usage.CacheWriteTokens != 11 {
t.Fatalf("unexpected usage: %+v", usage)
}
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for result")
}
}
func mustMarshal(t *testing.T, v any) json.RawMessage {
t.Helper()
data, err := json.Marshal(v)