mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 15:57:58 +01:00
130 lines
3.4 KiB
Go
130 lines
3.4 KiB
Go
package ggml
|
|
|
|
import (
|
|
"bytes"
|
|
"math/rand/v2"
|
|
"os"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
)
|
|
|
|
func TestWriteGGUF(t *testing.T) {
|
|
b := bytes.NewBuffer(make([]byte, 2*3))
|
|
for range 8 {
|
|
t.Run("shuffle", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ts := []*Tensor{
|
|
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
|
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
|
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
|
}
|
|
|
|
rand.Shuffle(len(ts), func(i, j int) {
|
|
ts[i], ts[j] = ts[j], ts[i]
|
|
})
|
|
|
|
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer w.Close()
|
|
|
|
if err := WriteGGUF(w, KV{
|
|
"general.alignment": uint32(16),
|
|
}, ts); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
r, err := os.Open(w.Name())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer r.Close()
|
|
|
|
ff, err := Decode(r, 0)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if diff := cmp.Diff(KV{
|
|
"general.alignment": uint32(16),
|
|
"general.parameter_count": uint64(54),
|
|
}, ff.KV()); diff != "" {
|
|
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
|
}
|
|
|
|
if diff := cmp.Diff(Tensors{
|
|
Offset: 592,
|
|
items: []*Tensor{
|
|
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
|
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
|
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
|
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
|
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
|
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
|
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
|
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
|
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
|
},
|
|
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
|
|
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func BenchmarkReadArray(b *testing.B) {
|
|
b.ReportAllocs()
|
|
|
|
create := func(tb testing.TB, kv KV) string {
|
|
tb.Helper()
|
|
f, err := os.CreateTemp(b.TempDir(), "")
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := WriteGGUF(f, kv, nil); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
|
|
return f.Name()
|
|
}
|
|
|
|
cases := map[string]any{
|
|
"int32": slices.Repeat([]int32{42}, 1_000_000),
|
|
"uint32": slices.Repeat([]uint32{42}, 1_000_000),
|
|
"float32": slices.Repeat([]float32{42.}, 1_000_000),
|
|
"string": slices.Repeat([]string{"42"}, 1_000_000),
|
|
}
|
|
|
|
for name, bb := range cases {
|
|
for _, maxArraySize := range []int{-1, 0, 1024} {
|
|
b.Run(name+"-maxArraySize="+strconv.Itoa(maxArraySize), func(b *testing.B) {
|
|
p := create(b, KV{"array": bb})
|
|
for b.Loop() {
|
|
f, err := os.Open(p)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
if _, err := Decode(f, maxArraySize); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
f.Close()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|