model: Load tensors behind an interface

Currently, if a model uses an interface for its data structures (as mllama
does), the tensor data in the structs implementing that interface is not
loaded.
This commit is contained in:
Jesse Gross 2025-01-14 16:12:14 -08:00 committed by Jesse Gross
parent d223f3b697
commit d650ad398f
2 changed files with 31 additions and 16 deletions

View File

@ -147,15 +147,12 @@ func New(s string) (Model, error) {
}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(b, v))
v.Elem().Set(populateFields(b, v.Elem()))
return m, nil
}
func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
if t.Kind() == reflect.Pointer {
t, v = t.Elem(), v.Elem()
}
if t.Kind() == reflect.Struct {
allNil := true
@ -205,18 +202,16 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
break
}
}
} else if tt.Kind() == reflect.Pointer {
vvv := vv.Elem()
if vv.IsNil() {
vvv = reflect.New(tt.Elem())
}
if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
vv.Set(f.Addr())
}
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
setPointer(b, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
for i := range vv.Len() {
vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
}
}
}
@ -233,6 +228,26 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
return v
}
// setPointer populates the value that v refers to and, when the populated
// value is addressable, stores its address back into v. v is expected to be a
// settable reflect.Value of kind Pointer or Interface.
//
//   - A nil interface is left untouched: there is no concrete type to
//     instantiate.
//   - An interface holding something other than a pointer is left untouched:
//     the held value is not addressable, so it cannot be populated in place.
//   - A nil pointer (either v itself, or a typed nil pointer held by an
//     interface) is populated through a freshly allocated value of the
//     pointee type.
func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
	// ptr is the pointer to follow: v itself, or the concrete value stored
	// in v when v is an interface.
	ptr := v
	if v.Kind() == reflect.Interface {
		if v.IsNil() {
			return
		}

		ptr = v.Elem()
		if ptr.Kind() != reflect.Pointer {
			// Calling Elem on a non-pointer concrete value would panic.
			return
		}
	}

	elem := ptr.Elem()
	if ptr.IsNil() {
		// Check the pointer itself rather than v: an interface holding a
		// typed nil pointer is non-nil, yet Elem on that pointer yields the
		// zero Value, which must not be handed to populateFields.
		elem = reflect.New(ptr.Type().Elem()).Elem()
	}

	if f := populateFields(b, elem, tags...); f.CanAddr() {
		v.Set(f.Addr())
	}
}
type Tag struct {
Name string
Alternate []string

View File

@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
"output_norm.weight",
"output.weight",
},
}, v))
}, v.Elem()))
if diff := cmp.Diff(fakeModel{
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
@ -125,7 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
names: []string{
"input.weight",
},
}, v))
}, v.Elem()))
if diff := cmp.Diff(fakeModel{
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},