2025-02-14 00:31:21 +00:00
|
|
|
package model
|
|
|
|
|
|
|
|
import (
|
2024-12-17 19:59:41 -08:00
|
|
|
"errors"
|
2025-02-14 00:31:21 +00:00
|
|
|
"fmt"
|
|
|
|
"image"
|
|
|
|
_ "image/jpeg"
|
|
|
|
_ "image/png"
|
|
|
|
"log/slog"
|
|
|
|
"os"
|
|
|
|
"reflect"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
_ "golang.org/x/image/bmp"
|
|
|
|
_ "golang.org/x/image/tiff"
|
|
|
|
_ "golang.org/x/image/webp"
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
"github.com/ollama/ollama/kvcache"
|
2025-02-14 00:31:21 +00:00
|
|
|
"github.com/ollama/ollama/ml"
|
|
|
|
_ "github.com/ollama/ollama/ml/backend"
|
|
|
|
)
|
|
|
|
|
2025-02-14 16:01:00 -08:00
|
|
|
// Options bundles everything a single model forward pass consumes.
type Options struct {
	// Inputs holds the batch input values (presumably token IDs — confirm
	// against the callers that build Options).
	Inputs []int32

	// Positions and Sequences are parallel slices: Forward rejects a batch
	// whose Positions and Sequences differ in length or are empty, and hands
	// both to the cache before the model runs.
	Positions []int32
	Sequences []int

	// Outputs is passed through to the model untouched by this package
	// (presumably the indices whose outputs are wanted — confirm).
	Outputs []int32

	// Images carries any decoded images that accompany the batch.
	Images []image.Image
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
type config struct {
|
|
|
|
Cache kvcache.Cache
|
2025-02-14 00:31:21 +00:00
|
|
|
}
|
|
|
|
|
2025-02-14 16:01:00 -08:00
|
|
|
// Base implements the common fields and methods for all models
|
2025-02-14 00:31:21 +00:00
|
|
|
type Base struct {
|
|
|
|
b ml.Backend
|
2024-12-17 19:59:41 -08:00
|
|
|
config
|
2025-02-14 00:31:21 +00:00
|
|
|
}
|
|
|
|
|
2025-02-14 16:01:00 -08:00
|
|
|
// Backend returns the underlying backend that will run the model
|
2025-02-14 00:31:21 +00:00
|
|
|
func (m *Base) Backend() ml.Backend {
|
|
|
|
return m.b
|
|
|
|
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
func (m *Base) Config() config {
|
|
|
|
return m.config
|
|
|
|
}
|
|
|
|
|
2025-02-14 16:01:00 -08:00
|
|
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
2025-02-14 00:31:21 +00:00
|
|
|
type Model interface {
|
|
|
|
Forward(ml.Context, Options) (ml.Tensor, error)
|
|
|
|
|
|
|
|
Backend() ml.Backend
|
2024-12-17 19:59:41 -08:00
|
|
|
Config() config
|
2025-02-14 00:31:21 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// models maps an architecture name to the constructor registered for it.
// Register adds entries; New looks them up.
var models = map[string]func(ml.Config) (Model, error){}
|
|
|
|
|
2025-02-14 16:01:00 -08:00
|
|
|
// Register registers a model constructor for the given architecture
|
2025-02-14 00:31:21 +00:00
|
|
|
func Register(name string, f func(ml.Config) (Model, error)) {
|
|
|
|
if _, ok := models[name]; ok {
|
|
|
|
panic("model: model already registered")
|
|
|
|
}
|
|
|
|
|
|
|
|
models[name] = f
|
|
|
|
}
|
|
|
|
|
2025-02-14 16:01:00 -08:00
|
|
|
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
2025-02-20 11:18:01 -08:00
|
|
|
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
2025-02-14 16:01:00 -08:00
|
|
|
r, err := os.Open(modelPath)
|
2025-02-14 00:31:21 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
defer r.Close()
|
|
|
|
|
2025-02-20 11:18:01 -08:00
|
|
|
b, err := ml.NewBackend(r, params)
|
2025-02-14 00:31:21 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
arch := b.Config().Architecture()
|
|
|
|
f, ok := models[arch]
|
|
|
|
if !ok {
|
|
|
|
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
|
|
|
}
|
|
|
|
|
|
|
|
m, err := f(b.Config())
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
base := Base{b: b, config: m.Config()}
|
|
|
|
|
2025-02-14 00:31:21 +00:00
|
|
|
v := reflect.ValueOf(m)
|
2024-12-17 19:59:41 -08:00
|
|
|
v.Elem().Set(populateFields(base, v.Elem()))
|
2025-02-14 00:31:21 +00:00
|
|
|
return m, nil
|
|
|
|
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
2025-02-14 00:31:21 +00:00
|
|
|
t := v.Type()
|
|
|
|
|
|
|
|
if t.Kind() == reflect.Struct {
|
|
|
|
allNil := true
|
|
|
|
for i := range t.NumField() {
|
|
|
|
tt := t.Field(i).Type
|
|
|
|
vv := v.Field(i)
|
|
|
|
if !vv.CanSet() {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
// make a copy
|
|
|
|
tagsCopy := tags
|
|
|
|
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
|
|
|
tagsCopy = append(tagsCopy, ParseTags(tag))
|
|
|
|
}
|
|
|
|
|
|
|
|
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
2024-12-17 19:59:41 -08:00
|
|
|
vv.Set(reflect.ValueOf(base))
|
2025-02-14 00:31:21 +00:00
|
|
|
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
|
|
|
var fn func([]Tag) [][]string
|
|
|
|
fn = func(tags []Tag) (values [][]string) {
|
|
|
|
if len(tags) < 1 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
values = [][]string{{tags[0].Name}}
|
|
|
|
for _, alt := range tags[0].Alternate {
|
|
|
|
values = append(values, []string{alt})
|
|
|
|
}
|
|
|
|
|
|
|
|
for i, value := range values {
|
|
|
|
for _, rest := range fn(tags[1:]) {
|
|
|
|
value = append(value, rest...)
|
|
|
|
}
|
|
|
|
|
|
|
|
values[i] = value
|
|
|
|
}
|
|
|
|
|
|
|
|
return values
|
|
|
|
}
|
|
|
|
|
|
|
|
names := fn(tagsCopy)
|
|
|
|
for _, name := range names {
|
2024-12-17 19:59:41 -08:00
|
|
|
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
2025-02-14 00:31:21 +00:00
|
|
|
slog.Debug("found tensor", "", tensor)
|
|
|
|
vv.Set(reflect.ValueOf(tensor))
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
2025-01-14 16:12:14 -08:00
|
|
|
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
2024-12-17 19:59:41 -08:00
|
|
|
setPointer(base, vv, tagsCopy)
|
2025-02-14 00:31:21 +00:00
|
|
|
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
|
|
|
for i := range vv.Len() {
|
2025-01-14 16:12:14 -08:00
|
|
|
vvv := vv.Index(i)
|
|
|
|
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
2024-12-17 19:59:41 -08:00
|
|
|
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
2025-01-14 16:12:14 -08:00
|
|
|
} else {
|
2024-12-17 19:59:41 -08:00
|
|
|
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
2025-01-14 16:12:14 -08:00
|
|
|
}
|
2025-02-14 00:31:21 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if !canNil(tt) || !vv.IsNil() {
|
|
|
|
allNil = false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if allNil {
|
|
|
|
return reflect.Zero(t)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return v
|
|
|
|
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
func setPointer(base Base, v reflect.Value, tags []Tag) {
|
2025-01-14 16:12:14 -08:00
|
|
|
vv := v
|
|
|
|
if v.Kind() == reflect.Interface {
|
|
|
|
if v.IsNil() {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
vv = vv.Elem()
|
|
|
|
}
|
|
|
|
|
|
|
|
vv = vv.Elem()
|
|
|
|
if v.IsNil() {
|
|
|
|
vv = reflect.New(v.Type().Elem()).Elem()
|
|
|
|
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
if f := populateFields(base, vv, tags...); f.CanAddr() {
|
2025-01-14 16:12:14 -08:00
|
|
|
v.Set(f.Addr())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-02-14 00:31:21 +00:00
|
|
|
// Tag is the parsed form of a `gguf` struct tag: one element of a tensor
// name path together with any alternative spellings of that element.
type Tag struct {
	// Name is the primary name, taken from the first comma-separated part.
	Name string

	// Alternate lists the values of every "alt:"-prefixed part, in order.
	Alternate []string
}
|
|
|
|
|
|
|
|
func ParseTags(s string) (tag Tag) {
|
|
|
|
parts := strings.Split(s, ",")
|
|
|
|
if len(parts) > 0 {
|
|
|
|
tag.Name = parts[0]
|
|
|
|
|
|
|
|
for _, part := range parts[1:] {
|
|
|
|
if value, ok := strings.CutPrefix(part, "alt:"); ok {
|
|
|
|
tag.Alternate = append(tag.Alternate, value)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// canNil reports whether values of type t can be nil, i.e. whether t's kind
// is chan, func, interface, map, pointer or slice.
func canNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Chan,
		reflect.Func,
		reflect.Interface,
		reflect.Map,
		reflect.Pointer,
		reflect.Slice:
		return true
	default:
		return false
	}
}
|
|
|
|
|
2024-12-17 19:59:41 -08:00
|
|
|
func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
|
|
|
|
if len(opts.Positions) != len(opts.Sequences) {
|
|
|
|
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(opts.Positions) < 1 {
|
|
|
|
return nil, errors.New("batch size cannot be less than 1")
|
|
|
|
}
|
|
|
|
|
|
|
|
cache := m.Config().Cache
|
|
|
|
if cache != nil {
|
|
|
|
err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2025-02-14 00:31:21 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
t, err := m.Forward(ctx, opts)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2025-02-21 11:57:08 -08:00
|
|
|
ctx.Forward(t).Compute(t)
|
2025-02-03 19:35:12 -08:00
|
|
|
|
|
|
|
return t, nil
|
2025-02-14 00:31:21 +00:00
|
|
|
}
|