Runner for Ollama engine
This provides integration with the new Ollama engine
(5824541: next ollama runner (#7913)) and the rest of the Ollama
infrastructure, such as the runner and the Ollama server.

It also builds out the KV cache infrastructure to support the
requirements of how Ollama runs models, such as:
- Parallel processing
- Memory management for defragmentation and shifting
- Multi-modal models
Both the old and new engines continue to be supported. By default, only
the old engine is used. To enable the new engine:

Start the server with the OLLAMA_NEW_ENGINE environment variable set:

    OLLAMA_NEW_ENGINE=1 ./ollama serve

Then run a model supported by the new engine. This one is Llama 3.1 8B Q4_K_M:

    ./ollama run jessegross/llama3.1
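For a sense of the new calling convention (see the model/model.go diff
below): a minimal sketch of a runner-side call, with hypothetical token
IDs and error handling elided. Inputs, Positions, and Sequences are
parallel slices, so one batch can advance several sequences at once,
which is where the parallel processing support comes from:

    // Sketch only: m is a Model returned by model.New; the caller now
    // owns the compute context instead of Forward creating one.
    ctx := m.Backend().NewContext()
    defer ctx.Close()

    logits, err := model.Forward(ctx, m, model.Options{
        Inputs:    []int32{1015, 3278}, // hypothetical token IDs
        Positions: []int32{17, 42},     // position of each token in its sequence
        Sequences: []int{0, 1},         // KV-cache sequence each token extends
        Outputs:   []int32{0, 1},       // batch indices that should produce logits
    })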
model/model.go | 128
@@ -1,6 +1,7 @@
 package model

 import (
 	"errors"
+	"fmt"
 	"image"
 	_ "image/jpeg"
@@ -15,102 +16,42 @@ import (
 	_ "golang.org/x/image/tiff"
 	_ "golang.org/x/image/webp"

-	"github.com/ollama/ollama/cache"
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
 )

-type Cache struct {
-	cache.Cache
-	cache.Options
-}
-
-func (c Cache) Sub(i int) Cache {
-	if c.Cache != nil {
-		return Cache{
-			Cache:   c.Cache.Sub(i),
-			Options: c.Options,
-		}
-	}
-
-	return c
-}
-
-func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
-	if c.Cache != nil {
-		return c.Cache.Put(ctx, key, value, opts)
-	}
-
-	return key, value
-}
-
 type Options struct {
-	inputs []int32
-
-	Offset int
+	Inputs    []int32
+	Positions []int32
+	Sequences []int
+	Outputs   []int32

 	Images []image.Image
-
-	Cache
 }

-func (opts Options) Inputs() []int32 {
-	return opts.inputs[opts.Offset:]
-}
-
-func (opts Options) Positions() []int32 {
-	positions := make([]int32, len(opts.inputs)-opts.Offset)
-	for i := range positions {
-		positions[i] = int32(opts.Offset + i)
-	}
-
-	return positions
-}
-
-type OptionsFunc func(Model, *Options)
-
-func WithInputIDs(ids []int32) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.inputs = ids
-	}
-}
-
-func WithOffset(offset int) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.Offset = offset
-		opts.Cache.Position = offset
-	}
-}
-
-func WithImage(img image.Image) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.Images = append(opts.Images, img)
-	}
-}
-
-func WithCache(c cache.Cache) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.Cache = Cache{
-			Cache: c,
-			Options: cache.Options{
-				Position: opts.Offset,
-			},
-		}
-	}
-}
+type config struct {
+	Cache kvcache.Cache
+}

 type Base struct {
 	b ml.Backend
+	config
 }

 func (m *Base) Backend() ml.Backend {
 	return m.b
 }

+func (m *Base) Config() config {
+	return m.config
+}
+
 type Model interface {
 	Forward(ml.Context, Options) (ml.Tensor, error)

 	Backend() ml.Backend
+	Config() config
 }

 var models = make(map[string]func(ml.Config) (Model, error))
@@ -146,12 +87,14 @@ func New(s string) (Model, error) {
 		return nil, err
 	}

+	base := Base{b: b, config: m.Config()}
+
 	v := reflect.ValueOf(m)
-	v.Elem().Set(populateFields(b, v.Elem()))
+	v.Elem().Set(populateFields(base, v.Elem()))
 	return m, nil
 }

-func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
+func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 	t := v.Type()

 	if t.Kind() == reflect.Struct {
@@ -170,7 +113,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 		}

 		if tt == reflect.TypeOf((*Base)(nil)).Elem() {
-			vv.Set(reflect.ValueOf(Base{b: b}))
+			vv.Set(reflect.ValueOf(base))
 		} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
 			var fn func([]Tag) [][]string
 			fn = func(tags []Tag) (values [][]string) {
@@ -196,21 +139,21 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {

 			names := fn(tagsCopy)
 			for _, name := range names {
-				if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
+				if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
 					slog.Debug("found tensor", "", tensor)
 					vv.Set(reflect.ValueOf(tensor))
 					break
 				}
 			}
 		} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
-			setPointer(b, vv, tagsCopy)
+			setPointer(base, vv, tagsCopy)
 		} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
 			for i := range vv.Len() {
 				vvv := vv.Index(i)
 				if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
-					setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
+					setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
 				} else {
-					vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
+					vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
 				}
 			}
 		}
@@ -228,7 +171,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 	return v
 }

-func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
+func setPointer(base Base, v reflect.Value, tags []Tag) {
 	vv := v
 	if v.Kind() == reflect.Interface {
 		if v.IsNil() {
@@ -243,7 +186,7 @@ func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
 		vv = reflect.New(v.Type().Elem()).Elem()
 	}

-	if f := populateFields(b, vv, tags...); f.CanAddr() {
+	if f := populateFields(base, vv, tags...); f.CanAddr() {
 		v.Set(f.Addr())
 	}
 }
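The reflection machinery above is unchanged apart from threading Base
through: gguf tags still accumulate along the struct path and are joined
with "." to look tensors up by name. A hypothetical fragment showing the
naming scheme (tags modeled on the llama definitions later in this diff):

    // populateFields resolves Layers[0].Query to the backend tensor
    // named "blk.0.attn_q.weight" (slice indices become path segments).
    type layer struct {
        Query *nn.Linear `gguf:"attn_q"`
    }

    type example struct {
        model.Base
        Layers []layer `gguf:"blk"`
    }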
@@ -277,18 +220,27 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }

-func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
-	var opts Options
-	for _, optsFunc := range optsFuncs {
-		optsFunc(m, &opts)
+func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+	if len(opts.Positions) != len(opts.Sequences) {
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
 	}

+	if len(opts.Positions) < 1 {
+		return nil, errors.New("batch size cannot be less than 1")
+	}
+
+	cache := m.Config().Cache
+	if cache != nil {
+		err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+		if err != nil {
+			return nil, err
+		}
+	}
+
-	ctx := m.Backend().NewContext()
 	t, err := m.Forward(ctx, opts)
 	if err != nil {
 		return nil, err
 	}
-	defer ctx.Close()

 	ctx.Forward(t)
 	ctx.Compute(t)
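The unexported config embedded in Base is how the generic Forward above
finds a model's cache: the constructor sets it, New copies it into Base
alongside the backend, and Forward calls StartForward on it before the
model runs. A sketch from a model author's point of view (tinyModel and
its no-op Shift are hypothetical; compare the real llama changes below):

    type tinyModel struct {
        model.Base // embeds the backend and the config that holds the cache
        // ... tensor fields with gguf tags ...
    }

    func newTinyModel(c ml.Config) (model.Model, error) {
        m := tinyModel{}
        // Cache is promoted from the embedded config; model.Forward will
        // call StartForward on it with the batch's positions and sequences.
        m.Cache = kvcache.NewCausalCache(m.Shift)
        return &m, nil
    }

    func (m *tinyModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
        return key, nil // a real model re-applies RoPE here (see llama below)
    }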
@@ -78,7 +78,7 @@ func TestPopulateFields(t *testing.T) {

 	var m fakeModel
 	v := reflect.ValueOf(&m)
-	v.Elem().Set(populateFields(&fakeBackend{
+	v.Elem().Set(populateFields(Base{b: &fakeBackend{
 		names: []string{
 			"input.weight",
 			"blk.0.attn_q.weight",
@@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
 			"output_norm.weight",
 			"output.weight",
 		},
-	}, v.Elem()))
+	}}, v.Elem()))

 	if diff := cmp.Diff(fakeModel{
 		Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
@@ -121,11 +121,11 @@ func TestPopulateFieldsAlternateName(t *testing.T) {

 	m := fakeModel{}
 	v := reflect.ValueOf(&m)
-	v.Elem().Set(populateFields(&fakeBackend{
+	v.Elem().Set(populateFields(Base{b: &fakeBackend{
 		names: []string{
 			"input.weight",
 		},
-	}, v.Elem()))
+	}}, v.Elem()))

 	if diff := cmp.Diff(fakeModel{
 		Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
@@ -3,6 +3,7 @@ package llama
 import (
 	"math"

+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
@@ -28,7 +29,7 @@ type Model struct {
 }

 func New(c ml.Config) (model.Model, error) {
-	return &Model{
+	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
@@ -49,7 +50,11 @@ func New(c ml.Config) (model.Model, error) {
 			ropeScale: c.Float("rope.freq_scale", 1),
 			ropeDim:   c.Uint("rope.dimension_count"),
 		},
-	}, nil
+	}
+
+	m.Cache = kvcache.NewCausalCache(m.Shift)
+
+	return &m, nil
 }

 type SelfAttention struct {
@@ -59,7 +64,7 @@ type SelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }

-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads

@@ -74,7 +79,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-	k, v = cache.Put(ctx, k, v, cache.Options)
+	cache.Put(ctx, k, v)
+	k, v, mask := cache.Get(ctx)

 	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -82,6 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

 	kq := k.MulmatFullPrec(ctx, q)
 	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+	kq = kq.Add(ctx, mask)
 	kq = kq.Softmax(ctx)

 	kqv := v.Mulmat(ctx, kq)
@@ -91,6 +98,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	return sa.Output.Forward(ctx, kqv)
 }

+func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+}
+
 type MLP struct {
 	Up   *nn.Linear `gguf:"ffn_up"`
 	Down *nn.Linear `gguf:"ffn_down"`
@@ -109,7 +120,7 @@ type Layer struct {
 	MLP *MLP
 }

-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState

 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -123,12 +134,12 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach
 }

 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
+	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}

-	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
+	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}
@@ -136,13 +147,14 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

 	for i, layer := range m.Layers {
-		hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
+		m.Cache.SetLayer(i)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
 	}

 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	hiddenState = m.Output.Forward(ctx, hiddenState)

-	outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
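Shift is the model's half of the cache's shifting support (the memory
management bullet in the commit message): when the runner slides a
sequence's window, the cache re-rotates stored keys through this callback
rather than recomputing the sequence. A sketch of the cache-side
contract; the loop and names here are illustrative, not the kvcache
internals:

    // Illustrative only: for each cached layer, pass the stored keys and
    // a tensor of position deltas through Shift and keep the result.
    for layer, key := range cachedKeys {
        shifted, err := m.Shift(ctx, layer, key, deltas)
        if err != nil {
            return err
        }
        cachedKeys[layer] = shifted
    }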
@@ -1,6 +1,7 @@
 package mllama

 import (
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
@@ -18,8 +19,13 @@ type Model struct {
 	ImageProcessor
 }

+const (
+	crossAttentionLayer = iota
+	selfAttentionLayer
+)
+
 func New(c ml.Config) (model.Model, error) {
-	return &Model{
+	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
@@ -33,7 +39,11 @@ func New(c ml.Config) (model.Model, error) {
 		ImageProcessor: newImageProcessor(c),
 		VisionModel:    newVisionModel(c),
 		TextModel:      newTextModel(c),
-	}, nil
+	}
+
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
+
+	return &m, nil
 }

 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
@@ -73,20 +83,20 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
 	}

-	inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
+	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}

-	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
+	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}

 	// TODO: attention mask, cross attention mask
-	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
+	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))

-	outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
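The WrapperCache pairs one cache per layer type, and the index order must
line up with the constants above: crossAttentionLayer (0) selects the
EncoderCache holding image-derived keys/values that persist across text
batches, while selfAttentionLayer (1) selects the ordinary causal cache.
Restated from the diff:

    m.Cache = kvcache.NewWrapperCache(
        kvcache.NewEncoderCache(),                 // cross-attention K/V (image states)
        kvcache.NewCausalCache(m.TextModel.Shift), // self-attention K/V (text)
    )

Cross-attention layers then compute and cache K/V only when fresh
crossAttentionStates arrive, and otherwise read the cached image
projections back, which is what the TextCrossAttention change below
implements.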
@@ -4,9 +4,9 @@ import (
 	"math"
 	"slices"

+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/model"
 )

 type TextSelfAttention struct {
@@ -16,7 +16,7 @@ type TextSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }

-func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads

@@ -31,7 +31,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-	key, value = cache.Put(ctx, key, value, cache.Options)
+	cache.Put(ctx, key, value)
+	key, value, mask := cache.Get(ctx)

 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -39,11 +40,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas

 	scores := key.MulmatFullPrec(ctx, query)
 	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-
-	if mask != nil {
-		scores = scores.Add(ctx, mask)
-	}
-
+	scores = scores.Add(ctx, mask)
 	scores = scores.Softmax(ctx)

 	attention := value.Mulmat(ctx, scores)
@@ -53,6 +50,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 	return sa.Output.Forward(ctx, attention)
 }

+func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	// This will only get called for layers in the cache, which are just the self attention layers
+	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+}
+
 type TextMLP struct {
 	Up   *nn.Linear `gguf:"ffn_up"`
 	Down *nn.Linear `gguf:"ffn_down"`
@@ -72,7 +74,7 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP *TextMLP
 }

-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState

 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -94,23 +96,29 @@ type TextCrossAttention struct {
 	Output *nn.Linear `gguf:"cross_attn_o_proj"`
 }

-func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

 	query := ca.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)

-	key := ca.Key.Forward(ctx, crossAttentionStates)
-	key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
-	key = ca.KeyNorm.Forward(ctx, key, opts.eps)
-
-	value := ca.Value.Forward(ctx, crossAttentionStates)
-	value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
-
-	// TODO cache key, value
+	var key, value ml.Tensor
+	if crossAttentionStates != nil {
+		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
+
+		key = ca.Key.Forward(ctx, crossAttentionStates)
+		key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+		key = ca.KeyNorm.Forward(ctx, key, opts.eps)
+
+		value = ca.Value.Forward(ctx, crossAttentionStates)
+		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+
+		cache.Put(ctx, key, value)
+	} else {
+		key, value, _ = cache.Get(ctx)
+	}

 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -137,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
 }

-func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState

 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -153,17 +161,25 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 }

 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
 }

 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }

-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
-		if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
+		layerType := selfAttentionLayer
+		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
+			layerType = crossAttentionLayer
+		}
+
+		cache.SetLayer(i)
+		cache.SetLayerType(layerType)
+
+		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
 		}
 	}

@@ -189,7 +205,7 @@ type TextModel struct {
 	*TextModelOptions
 }

-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
 	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)