mirror of
https://github.com/ollama/ollama.git
synced 2025-12-14 06:02:50 +01:00
fix tensor merge (#13053)
This commit is contained in:
@@ -2,10 +2,12 @@ package convert
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"iter"
|
"iter"
|
||||||
"path"
|
"path"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pdevine/tensor"
|
"github.com/pdevine/tensor"
|
||||||
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
|
|||||||
return matched
|
return matched
|
||||||
})
|
})
|
||||||
|
|
||||||
|
slices.SortStableFunc(matched, func(a, b Tensor) int {
|
||||||
|
x := strings.Split(a.Name(), ".")
|
||||||
|
y := strings.Split(b.Name(), ".")
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return cmp.Compare(len(x), len(y))
|
||||||
|
}
|
||||||
|
|
||||||
|
vals := make([]int, len(x))
|
||||||
|
for i := range x {
|
||||||
|
vals[i] = strings.Compare(x[i], y[i])
|
||||||
|
m, err := strconv.ParseInt(x[i], 0, 0)
|
||||||
|
n, err2 := strconv.ParseInt(y[i], 0, 0)
|
||||||
|
if errors.Join(err, err2) == nil {
|
||||||
|
vals[i] = cmp.Compare(m, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmp.Or(vals...)
|
||||||
|
})
|
||||||
|
|
||||||
if len(matched) > 0 {
|
if len(matched) > 0 {
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{
|
||||||
Name: merges[i].name,
|
Name: merges[i].name,
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ package convert
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"iter"
|
"iter"
|
||||||
|
"math/rand/v2"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMergeOrder(t *testing.T) {
|
||||||
|
for range 8 {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
tensors := make([]Tensor, 16)
|
||||||
|
for i := range tensors {
|
||||||
|
tensors[i] = &fakeTensor{
|
||||||
|
name: fmt.Sprintf("layer.%d.weight", i),
|
||||||
|
shape: []uint64{1},
|
||||||
|
data: []float32{float32(i)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rand.Shuffle(len(tensors), func(i, j int) {
|
||||||
|
tensors[i], tensors[j] = tensors[j], tensors[i]
|
||||||
|
})
|
||||||
|
|
||||||
|
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
|
||||||
|
if len(unmatched) != 0 {
|
||||||
|
t.Error("expected no remaining tensors, got", len(unmatched))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matched) != 1 {
|
||||||
|
t.Error("expected 1 merged tensor, got", len(matched))
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := matched[0].WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var f32s [16]float32
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.IsSorted(f32s[:]) {
|
||||||
|
t.Errorf("merged tensor data is not in order: %+v", f32s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user