mirror of
synced 2025-03-18 22:01:47 +01:00
Update Context.Forward to accept multiple tensors, matching the Context.Compute signature. Update Context.Forward to return Context so that it can be chained with Context.Compute.
255 lines
5.3 KiB
255 lines
5.3 KiB
package model
import (
_ "image/jpeg"
_ "image/png"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
_ "github.com/ollama/ollama/ml/backend"
// Options contains the inputs for a model forward pass.
type Options struct {
	// Inputs holds the int32 input values for the batch.
	// NOTE(review): presumably token IDs — confirm against callers.
	Inputs []int32

	// Positions and Sequences are parallel slices: Forward rejects the
	// batch unless they have the same, non-zero length, and hands both to
	// the model's cache (when it has one) before running the model.
	Positions []int32
	Sequences []int

	// Outputs holds int32 values describing the requested outputs.
	// NOTE(review): presumably indices into the batch — confirm against callers.
	Outputs []int32

	// Images holds the images that accompany the inputs.
	Images []image.Image
}
type config struct {
Cache kvcache.Cache
// Base implements the common fields and methods for all models
type Base struct {
b ml.Backend
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
func (m *Base) Config() config {
return m.config
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, Options) (ml.Tensor, error)
Backend() ml.Backend
Config() config
// models maps an architecture name to the constructor that builds a Model
// for it. Register adds entries; New looks them up.
var models = map[string]func(ml.Config) (Model, error){}
// Register registers a model constructor for the given architecture
func Register(name string, f func(ml.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
models[name] = f
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
defer r.Close()
b, err := ml.NewBackend(r, params)
if err != nil {
return nil, err
arch := b.Config().Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
m, err := f(b.Config())
if err != nil {
return nil, err
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
// populateFields uses reflection to walk the fields of v and fill them in
// from base, and returns the resulting value.
//
// tags is the chain of parsed `gguf` struct tags accumulated from the
// enclosing fields. For a field of type ml.Tensor, candidate names are built
// from that chain, joined with ".", and looked up with base.Backend().Get.
//
// When v is a struct and no settable field marks it as non-nil, the zero
// value of v's type is returned instead of v. A non-struct v is returned
// unchanged.
//
// NOTE(review): this copy of the function appears to have lost lines (closing
// braces and some single-statement lines), so several branch bodies are not
// visible below. The comments describe only what is shown.
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
// Only struct values are walked; anything else falls through to the final return.
if t.Kind() == reflect.Struct {
// allNil stays true until a field is found that cannot be nil or is non-nil.
allNil := true
for i := range t.NumField() {
tt := t.Field(i).Type
vv := v.Field(i)
// NOTE(review): the body of this check is not visible in this extract;
// presumably fields that cannot be set are skipped — confirm.
if !vv.CanSet() {
// Start from the inherited tags and extend them with this field's own
// gguf tag, if it has one.
// NOTE(review): this assignment copies only the slice header, so tagsCopy
// shares its backing array with tags and the appends below may write into
// shared spare capacity — confirm that this aliasing is harmless.
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, ParseTags(tag))
// Field whose type is Base.
// NOTE(review): the body of this branch is not visible in this extract;
// presumably base is stored into the field — confirm.
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
// Field whose type is ml.Tensor: build candidate names from the tag chain
// and ask the backend for each one.
// fn expands a tag chain into name paths. It starts one path per name of
// the first tag (the primary name, then each alternate) and extends every
// path with the paths produced for the remaining tags.
var fn func([]Tag) [][]string
fn = func(tags []Tag) (values [][]string) {
// An empty chain contributes no paths and ends the recursion.
if len(tags) < 1 {
return nil
values = [][]string{{tags[0].Name}}
for _, alt := range tags[0].Alternate {
values = append(values, []string{alt})
// NOTE(review): every path returned for the remaining tags is appended to
// the same value, one after another, so a later tag with alternates
// lengthens the path rather than producing separate paths — confirm this
// is intended.
for i, value := range values {
for _, rest := range fn(tags[1:]) {
value = append(value, rest...)
values[i] = value
return values
names := fn(tagsCopy)
for _, name := range names {
// The lookup key is the path components joined with ".".
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
slog.Debug("found tensor", "", tensor)
// NOTE(review): the statements that store the tensor in the field and
// stop the search are not visible in this extract.
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
// Pointer or interface field: setPointer handles it.
setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
// Slice or array field: populate each element, adding the element index
// as an extra tag so it becomes part of the name path.
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
// A field whose type cannot be nil, or whose value is non-nil, means the
// struct as a whole is not entirely nil.
if !canNil(tt) || !vv.IsNil() {
allNil = false
// Nothing marked the struct as non-nil: hand back its zero value.
if allNil {
return reflect.Zero(t)
return v
// setPointer populates the value that the pointer or interface v refers to,
// passing tags on to populateFields.
//
// NOTE(review): this copy of the function appears to have lost lines (closing
// braces and some single-statement lines), so the body of the nil-interface
// check and of the final if are not visible. The comments describe only what
// is shown.
func setPointer(base Base, v reflect.Value, tags []Tag) {
vv := v
// For an interface, unwrap to the value it holds.
if v.Kind() == reflect.Interface {
// NOTE(review): the body of this nil check is not visible in this extract;
// presumably a nil interface is left untouched — confirm.
if v.IsNil() {
vv = vv.Elem()
// Dereference to reach the pointed-to value.
vv = vv.Elem()
// v is nil, so there is nothing to dereference: allocate a fresh value of
// the pointed-to type and populate that instead.
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
// NOTE(review): the body of this check is not visible in this extract;
// presumably the address of the populated value is stored back into v when
// it is addressable — confirm.
if f := populateFields(base, vv, tags...); f.CanAddr() {
// Tag is the parsed form of a `gguf` struct tag.
type Tag struct {
	// Name is the primary name, taken from the first comma-separated part.
	Name string

	// Alternate lists the fallback names given by "alt:" parts, in the
	// order they appear.
	Alternate []string
}

// ParseTags parses the comma-separated tag string s into a Tag.
//
// The first part becomes Name. Every later part that starts with "alt:"
// contributes the remainder of that part to Alternate; other parts are
// ignored. An empty s yields a Tag with an empty Name and no alternates.
func ParseTags(s string) (tag Tag) {
	// strings.Split returns at least one element for a non-empty separator,
	// so parts[0] is always present, even when s is empty.
	parts := strings.Split(s, ",")
	tag.Name = parts[0]

	for _, part := range parts[1:] {
		if alt, ok := strings.CutPrefix(part, "alt:"); ok {
			tag.Alternate = append(tag.Alternate, alt)
		}
	}

	return tag
}
// canNil reports whether values of type t can be nil, i.e. whether t's kind
// is chan, func, interface, map, pointer or slice.
func canNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return true
	default:
		return false
	}
}
func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
if len(opts.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
if err != nil {
return nil, err
t, err := m.Forward(ctx, opts)
if err != nil {
return nil, err
return t, nil