diff --git a/model/model.go b/model/model.go index 9290b6d30..4a86d7d60 100644 --- a/model/model.go +++ b/model/model.go @@ -147,15 +147,12 @@ func New(s string) (Model, error) { } v := reflect.ValueOf(m) - v.Elem().Set(populateFields(b, v)) + v.Elem().Set(populateFields(b, v.Elem())) return m, nil } func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() - if t.Kind() == reflect.Pointer { - t, v = t.Elem(), v.Elem() - } if t.Kind() == reflect.Struct { allNil := true @@ -205,18 +202,16 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { break } } - } else if tt.Kind() == reflect.Pointer { - vvv := vv.Elem() - if vv.IsNil() { - vvv = reflect.New(tt.Elem()) - } - - if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() { - vv.Set(f.Addr()) - } + } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { + setPointer(b, vv, tagsCopy) } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { for i := range vv.Len() { - vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv := vv.Index(i) + if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { + setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + } else { + vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + } } } @@ -233,6 +228,26 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { return v } +func setPointer(b ml.Backend, v reflect.Value, tags []Tag) { + vv := v + if v.Kind() == reflect.Interface { + if v.IsNil() { + return + } + + vv = vv.Elem() + } + + vv = vv.Elem() + if v.IsNil() { + vv = reflect.New(v.Type().Elem()).Elem() + } + + if f := populateFields(b, vv, tags...); f.CanAddr() { + v.Set(f.Addr()) + } +} + type Tag struct { Name string Alternate []string diff --git a/model/model_test.go b/model/model_test.go index 0f6c4a7c7..2ba12acde 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) { "output_norm.weight", "output.weight", }, - }, v)) + }, v.Elem())) if diff := cmp.Diff(fakeModel{ Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}}, @@ -125,7 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { names: []string{ "input.weight", }, - }, v)) + }, v.Elem())) if diff := cmp.Diff(fakeModel{ Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},