ml: update Dump to handle precision

This commit is contained in:
Michael Yang 2025-02-10 16:50:49 -08:00
parent c4f127ee6d
commit 95eb87a052

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"os"
"strconv"
"strings"
)
@ -126,15 +127,19 @@ func Dump(t Tensor, opts ...DumpOptions) string {
switch t.DType() {
case DTypeF32:
return dump[[]float32](t, opts[0])
return dump[[]float32](t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeI32:
return dump[[]int32](t, opts[0])
return dump[[]int32](t, opts[0].Items, func(i int32) string {
return strconv.FormatInt(int64(i), 10)
})
default:
return "<unsupported>"
}
}
func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
bts := t.Bytes()
if bts == nil {
return "<nil>"
@ -154,10 +159,10 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
fmt.Fprint(&sb, "[")
defer func() { fmt.Fprint(&sb, "]") }()
for i := int64(0); i < dims[0]; i++ {
if i >= opts.Items && i < dims[0]-opts.Items {
if i >= items && i < dims[0]-items {
fmt.Fprint(&sb, "..., ")
// skip to next printable element
skip := dims[0] - 2*opts.Items
skip := dims[0] - 2*items
if len(dims) > 1 {
stride += mul(append(dims[1:], skip)...)
fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
@ -170,7 +175,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
}
} else {
fmt.Fprint(&sb, s[stride+i])
fmt.Fprint(&sb, fn(s[stride+i]))
if i < dims[0]-1 {
fmt.Fprint(&sb, ", ")
}