From bf78ed6ee94e593a7edae2e277a736379cbc2413 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 16:08:57 -0700 Subject: [PATCH] add pre:, suf: to tags (#12274) --- model/model.go | 67 ++++++++++++++++++++----------- model/model_test.go | 61 +++++++++++++++++++++++++--- model/models/llama4/model_text.go | 14 +------ 3 files changed, 101 insertions(+), 41 deletions(-) diff --git a/model/model.go b/model/model.go index f3d6bb3db2..2b6ad73172 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,7 @@ import ( "fmt" _ "image/jpeg" _ "image/png" + "log/slog" "os" "reflect" "strconv" @@ -171,35 +172,42 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { // make a copy tagsCopy := tags if tag := t.Field(i).Tag.Get("gguf"); tag != "" { - tagsCopy = append(tagsCopy, ParseTags(tag)) + tagsCopy = append(tagsCopy, parseTag(tag)) } if tt == reflect.TypeOf((*Base)(nil)).Elem() { vv.Set(reflect.ValueOf(base)) } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { - var fn func([]Tag) [][]string - fn = func(tags []Tag) (names [][]string) { + var fn func([]Tag, string, string) [][]string + fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { if len(tags) > 0 { - localNames := []string{tags[0].Name} - localNames = append(localNames, tags[0].Alternate...) + var names []string + if tags[0].name != "" { + for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) { + names = append(names, prefix+n+suffix) + } + } - for _, localName := range localNames { - fullName := []string{localName} - nested := fn(tags[1:]) - if len(nested) > 0 { - for _, rest := range nested { - names = append(names, append(fullName, rest...)) + if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 { + // no child names, append current names + fullNames = append(fullNames, names) + } else if len(names) == 0 { + // no current names, append child names + fullNames = append(fullNames, childNames...) + } else { + // combine current and child names + for _, name := range names { + for _, childName := range childNames { + fullNames = append(fullNames, append([]string{name}, childName...)) } - } else { - names = append(names, fullName) } } } - return names + return fullNames } - names := fn(tagsCopy) + names := fn(tagsCopy, "", "") for _, name := range names { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { logutil.Trace("found tensor", "", tensor) @@ -213,9 +221,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { for i := range vv.Len() { vvv := vv.Index(i) if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) } else { - vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) } } } @@ -254,18 +262,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { } type Tag struct { - Name string - Alternate []string + name, + // prefix and suffix are applied to child tags + prefix, + suffix string + alternatives []string } -func ParseTags(s string) (tag Tag) { +func parseTag(s string) (tag Tag) { parts := strings.Split(s, ",") if len(parts) > 0 { - tag.Name = parts[0] + tag.name = parts[0] for _, part := range parts[1:] { - if value, ok := strings.CutPrefix(part, "alt:"); ok { - tag.Alternate = append(tag.Alternate, value) + if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { + // elevate alternative to primary if no primary given + tag.name = value + slog.Warn("gguf tag has alt: but no primary name", "tag", s) + } else if ok { + tag.alternatives = append(tag.alternatives, value) + } + if value, ok := strings.CutPrefix(part, "pre:"); ok { + tag.prefix = value + } + if value, ok := strings.CutPrefix(part, "suf:"); ok { + tag.suffix = value } } } diff --git a/model/model_test.go b/model/model_test.go index 01080ffdf5..e47278540d 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) { { value: "output", want: Tag{ - Name: "output", + name: "output", }, }, { value: "output,alt:token_embd", want: Tag{ - Name: "output", - Alternate: []string{ + name: "output", + alternatives: []string{ "token_embd", }, }, @@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) { for _, tt := range cases { t.Run(tt.value, func(t *testing.T) { - got := ParseTags(tt.value) - if diff := cmp.Diff(tt.want, got); diff != "" { + got := parseTag(tt.value) + if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" { t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) } }) @@ -147,6 +147,57 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } } +func TestPopulateFieldsPrefixSuffixName(t *testing.T) { + type fakeBlock struct { + A *nn.Linear `gguf:"a"` + B *nn.Linear `gguf:",pre:b_"` + C *nn.Linear `gguf:",suf:_c"` + XY *nn.Linear `gguf:",pre:x_,suf:_y"` + } + + type fakeModel struct { + Blocks []fakeBlock `gguf:"blk"` + } + + m := fakeModel{ + Blocks: make([]fakeBlock, 2), + } + v := reflect.ValueOf(&m) + v.Elem().Set(populateFields(Base{b: &fakeBackend{ + names: []string{ + "blk.0.a.weight", + "blk.0.b_weight", + "blk.0.b_bias", + "blk.0.weight_c", + "blk.0.x_weight_y", + "blk.1.a.weight", + "blk.1.b_weight", + "blk.1.b_bias", + "blk.1.weight_c", + "blk.1.x_weight_y", + }, + }}, v.Elem())) + + if diff := cmp.Diff(fakeModel{ + Blocks: []fakeBlock{ + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}}, + }, + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}}, + }, + }, + }, m); diff != "" { + t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) + } +} + func TestModelForArch(t *testing.T) { type fakeModel struct { Model diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index e0f9326000..e056391f56 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens return nextStates } -// TextSharedExpert is TextMLP with different tensor names -type TextSharedExpert struct { - Gate *nn.Linear `gguf:"ffn_gate_shexp"` - Up *nn.Linear `gguf:"ffn_up_shexp"` - Down *nn.Linear `gguf:"ffn_down_shexp"` -} - -func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) - return mlp.Down.Forward(ctx, hiddenStates) -} - type TextMOE struct { Router *nn.Linear `gguf:"ffn_gate_inp"` Experts *TextExperts - SharedExpert *TextSharedExpert + SharedExpert *TextMLP `gguf:",suf:_shexp"` } func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {