diff --git a/convert/tensor.go b/convert/tensor.go
index 27bdd13ff1..68870744f4 100644
--- a/convert/tensor.go
+++ b/convert/tensor.go
@@ -2,10 +2,12 @@ package convert
 
 import (
 	"cmp"
+	"errors"
 	"io"
 	"iter"
 	"path"
 	"slices"
+	"strconv"
 	"strings"
 
 	"github.com/pdevine/tensor"
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
 			return matched
 		})
 
+		slices.SortStableFunc(matched, func(a, b Tensor) int {
+			x := strings.Split(a.Name(), ".")
+			y := strings.Split(b.Name(), ".")
+			if len(x) != len(y) {
+				return cmp.Compare(len(x), len(y))
+			}
+
+			vals := make([]int, len(x))
+			for i := range x {
+				vals[i] = strings.Compare(x[i], y[i])
+				m, err := strconv.ParseInt(x[i], 0, 0)
+				n, err2 := strconv.ParseInt(y[i], 0, 0)
+				if errors.Join(err, err2) == nil {
+					vals[i] = cmp.Compare(m, n)
+				}
+			}
+
+			return cmp.Or(vals...)
+		})
+
 		if len(matched) > 0 {
 			out = append(out, &ggml.Tensor{
 				Name: merges[i].name,
diff --git a/convert/tensor_test.go b/convert/tensor_test.go
index c1f58da6e4..e0dc2350a6 100644
--- a/convert/tensor_test.go
+++ b/convert/tensor_test.go
@@ -3,8 +3,10 @@ package convert
 import (
 	"bytes"
 	"encoding/binary"
+	"fmt"
 	"io"
 	"iter"
+	"math/rand/v2"
 	"slices"
 	"strings"
 	"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
 		}
 	})
 }
+
+func TestMergeOrder(t *testing.T) {
+	for range 8 {
+		t.Run("", func(t *testing.T) {
+			tensors := make([]Tensor, 16)
+			for i := range tensors {
+				tensors[i] = &fakeTensor{
+					name:  fmt.Sprintf("layer.%d.weight", i),
+					shape: []uint64{1},
+					data:  []float32{float32(i)},
+				}
+			}
+
+			rand.Shuffle(len(tensors), func(i, j int) {
+				tensors[i], tensors[j] = tensors[j], tensors[i]
+			})
+
+			matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
+			if len(unmatched) != 0 {
+				t.Error("expected no remaining tensors, got", len(unmatched))
+			}
+
+			if len(matched) != 1 {
+				t.Error("expected 1 merged tensor, got", len(matched))
+			}
+
+			var b bytes.Buffer
+			if _, err := matched[0].WriteTo(&b); err != nil {
+				t.Fatal(err)
+			}
+
+			var f32s [16]float32
+			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+				t.Fatal(err)
+			}
+
+			if !slices.IsSorted(f32s[:]) {
+				t.Errorf("merged tensor data is not in order: %+v", f32s)
+			}
+		})
+	}
+}