// Files
// ollama/fs/ggml/gguf_test.go
// 2025-08-28 17:02:59 -07:00
//
// 130 lines
// 3.4 KiB
// Go

package ggml
import (
"bytes"
"math/rand/v2"
"os"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWriteGGUF(t *testing.T) {
b := bytes.NewBuffer(make([]byte, 2*3))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
}
rand.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, ts); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 592,
items: []*Tensor{
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
}
}
// BenchmarkReadArray measures Decode on GGUF files whose KV holds a single
// one-million-element array, for several element types and for several
// maxArraySize values passed to Decode.
func BenchmarkReadArray(b *testing.B) {
	// create writes kv (and no tensors) to a new temporary file and returns
	// its path. Everything goes through tb rather than the enclosing b so that
	// the temp dir belongs to, and failures abort, the benchmark that is
	// actually running; FailNow must be called from that benchmark's goroutine.
	create := func(tb testing.TB, kv KV) string {
		tb.Helper()

		f, err := os.CreateTemp(tb.TempDir(), "")
		if err != nil {
			tb.Fatal(err)
		}
		defer f.Close()

		if err := WriteGGUF(f, kv, nil); err != nil {
			tb.Fatal(err)
		}

		return f.Name()
	}

	// A slice (not a map) keeps the sub-benchmark order stable across runs.
	cases := []struct {
		name  string
		array any
	}{
		{name: "int32", array: slices.Repeat([]int32{42}, 1_000_000)},
		{name: "uint32", array: slices.Repeat([]uint32{42}, 1_000_000)},
		{name: "float32", array: slices.Repeat([]float32{42.}, 1_000_000)},
		{name: "string", array: slices.Repeat([]string{"42"}, 1_000_000)},
	}

	for _, tc := range cases {
		for _, maxArraySize := range []int{-1, 0, 1024} {
			b.Run(tc.name+"-maxArraySize="+strconv.Itoa(maxArraySize), func(b *testing.B) {
				// Requested here because the sub-benchmark is the one whose
				// measurements are reported.
				b.ReportAllocs()

				p := create(b, KV{"array": tc.array})

				for b.Loop() {
					f, err := os.Open(p)
					if err != nil {
						b.Fatal(err)
					}

					// Close before checking the error so the handle is not
					// leaked when Decode fails.
					_, err = Decode(f, maxArraySize)
					f.Close()
					if err != nil {
						b.Fatal(err)
					}
				}
			})
		}
	}
}