mirror of
https://github.com/ollama/ollama.git
synced 2025-03-18 05:41:43 +01:00
model: Load tensors behind an interface
Currently, if a model uses an interface for its data structures (as mllama does), the tensor data in the structs implementing that interface does not get loaded.
This commit is contained in:
parent
d223f3b697
commit
d650ad398f
@ -147,15 +147,12 @@ func New(s string) (Model, error) {
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(b, v))
|
||||
v.Elem().Set(populateFields(b, v.Elem()))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t, v = t.Elem(), v.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() == reflect.Struct {
|
||||
allNil := true
|
||||
@ -205,18 +202,16 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer {
|
||||
vvv := vv.Elem()
|
||||
if vv.IsNil() {
|
||||
vvv = reflect.New(tt.Elem())
|
||||
}
|
||||
|
||||
if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
|
||||
vv.Set(f.Addr())
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||
setPointer(b, vv, tagsCopy)
|
||||
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||
for i := range vv.Len() {
|
||||
vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
||||
vvv := vv.Index(i)
|
||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||
setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
||||
} else {
|
||||
vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -233,6 +228,26 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
return v
|
||||
}
|
||||
|
||||
func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
|
||||
vv := v
|
||||
if v.Kind() == reflect.Interface {
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
if v.IsNil() {
|
||||
vv = reflect.New(v.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if f := populateFields(b, vv, tags...); f.CanAddr() {
|
||||
v.Set(f.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name string
|
||||
Alternate []string
|
||||
|
@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
|
||||
"output_norm.weight",
|
||||
"output.weight",
|
||||
},
|
||||
}, v))
|
||||
}, v.Elem()))
|
||||
|
||||
if diff := cmp.Diff(fakeModel{
|
||||
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
||||
@ -125,7 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
names: []string{
|
||||
"input.weight",
|
||||
},
|
||||
}, v))
|
||||
}, v.Elem()))
|
||||
|
||||
if diff := cmp.Diff(fakeModel{
|
||||
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
||||
|
Loading…
x
Reference in New Issue
Block a user