diff --git a/ml/backend.go b/ml/backend.go index 1ffd2f631..855ca245e 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "os" + "strconv" "strings" ) @@ -126,15 +127,19 @@ func Dump(t Tensor, opts ...DumpOptions) string { switch t.DType() { case DTypeF32: - return dump[[]float32](t, opts[0]) + return dump[[]float32](t, opts[0].Items, func(f float32) string { + return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) + }) case DTypeI32: - return dump[[]int32](t, opts[0]) + return dump[[]int32](t, opts[0].Items, func(i int32) string { + return strconv.FormatInt(int64(i), 10) + }) default: return "" } } -func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string { +func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string { bts := t.Bytes() if bts == nil { return "" @@ -154,10 +159,10 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string { fmt.Fprint(&sb, "[") defer func() { fmt.Fprint(&sb, "]") }() for i := int64(0); i < dims[0]; i++ { - if i >= opts.Items && i < dims[0]-opts.Items { + if i >= items && i < dims[0]-items { fmt.Fprint(&sb, "..., ") // skip to next printable element - skip := dims[0] - 2*opts.Items + skip := dims[0] - 2*items if len(dims) > 1 { stride += mul(append(dims[1:], skip)...) fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix) @@ -170,7 +175,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string { fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) } } else { - fmt.Fprint(&sb, s[stride+i]) + fmt.Fprint(&sb, fn(s[stride+i])) if i < dims[0]-1 { fmt.Fprint(&sb, ", ") }