From 9ed8bf14cb885509281d63731cda16637a7e0bd2 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 20 May 2025 15:51:08 -0700
Subject: [PATCH] ml: add more rope options (#10775)

---
 ml/backend.go                       | 16 --------------
 ml/backend/ggml/ggml.go             | 26 ++++++-----------------
 ml/nn/fast/rope.go                  | 21 ++++++++++++++++++
 ml/nn/rope/rope.go                  | 33 +++++++++++++++++++++++++++++
 model/models/gemma2/model.go        |  9 ++++----
 model/models/gemma3/model_text.go   |  9 ++++----
 model/models/llama/model.go         | 17 ++++++++-------
 model/models/llama4/model_text.go   |  8 ++++---
 model/models/mistral3/model_text.go | 16 +++++++-------
 model/models/mllama/model_text.go   | 13 ++++++------
 model/models/qwen25vl/model_text.go | 31 ++++++++++++++-------------
 11 files changed, 116 insertions(+), 83 deletions(-)
 create mode 100644 ml/nn/fast/rope.go
 create mode 100644 ml/nn/rope/rope.go

diff --git a/ml/backend.go b/ml/backend.go
index 2599d774f0..70497e6e36 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -115,21 +115,6 @@ type Context interface {
 	Layer(int) Context
 }
 
-// RopeOptions contains optional parameters for RoPE function
-type RopeOptions struct {
-	OriginalContextLen uint32
-}
-
-// RopeOption defines a function that modifies RopeOpts
-type RopeOption func(*RopeOptions)
-
-// WithContextLen sets a custom context length
-func WithContextLen(len uint32) RopeOption {
-	return func(opts *RopeOptions) {
-		opts.OriginalContextLen = len
-	}
-}
-
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -155,7 +140,6 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
 	Sin(ctx Context) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index f0b26b2fe7..3ae50acec4 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -30,6 +30,7 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -1074,28 +1075,15 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
-const (
-	ropeTypeNorm   C.int = 0
-	ropeTypeNeox   C.int = 2
-	ropeTypeMrope  C.int = 8
-	ropeTypeVision C.int = 24
-)
-
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
+func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
 	// Default options
-	opts := &ml.RopeOptions{
-		OriginalContextLen: 131072,
-	}
+	opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}
 
 	// Apply any provided options
 	for _, option := range options {
 		option(opts)
 	}
 
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{b: t.b}
-	}
-
 	dequant := t.t
 	if C.ggml_is_quantized(t.t._type) {
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
@@ -1106,11 +1094,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		t: C.ggml_rope_ext(
 			ctx.(*Context).ctx,
 			dequant,
-			positionIDs.(*Tensor).t,
-			ropeFactors.(*Tensor).t,
+			positions.(*Tensor).t,
+			opts.Factors.(*Tensor).t,
 			C.int(ropeDim),
-			C.int(ropeType),
-			C.int(opts.OriginalContextLen),
+			C.int(opts.Type),
+			C.int(opts.OriginalContextLength),
 			C.float(ropeBase),
 			C.float(ropeScale),
 			C.float(0.0),
diff --git a/ml/nn/fast/rope.go b/ml/nn/fast/rope.go
new file mode 100644
index 0000000000..b45938ebf3
--- /dev/null
+++ b/ml/nn/fast/rope.go
@@ -0,0 +1,21 @@
+// fast provides implementations of fast (fused) operations for increased performance.
+package fast
+
+import (
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn/rope"
+)
+
+// fastRoPE is an interface for tensors that support fast rotary positional embedding.
+type fastRoPE interface {
+	RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
+}
+
+// RoPE applies rotary positional embedding to tensor `t`.
+func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor {
+	if t, ok := t.(fastRoPE); ok {
+		return t.RoPE(ctx, positions, dim, base, scale, options...)
+	}
+
+	panic("RoPE not implemented for this tensor type")
+}
diff --git a/ml/nn/rope/rope.go b/ml/nn/rope/rope.go
new file mode 100644
index 0000000000..b0c00a5b95
--- /dev/null
+++ b/ml/nn/rope/rope.go
@@ -0,0 +1,33 @@
+package rope
+
+import "github.com/ollama/ollama/ml"
+
+// Options contains optional parameters for RoPE function
+type Options struct {
+	OriginalContextLength int
+	Type                  int
+	Factors               ml.Tensor
+}
+
+// WithOriginalContextLength sets a custom context length
+func WithOriginalContextLength(n int) func(*Options) {
+	return func(opts *Options) {
+		opts.OriginalContextLength = n
+	}
+}
+
+// WithType sets RoPE type to NeoX
+func WithTypeNeoX() func(*Options) {
+	return func(opts *Options) {
+		opts.Type = 2
+	}
+}
+
+// WithFactors sets custom rope factors
+func WithFactors(factors ml.Tensor) func(*Options) {
+	return func(opts *Options) {
+		if factors != nil {
+			opts.Factors = factors
+		}
+	}
+}
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index a87534c54e..3c5a7ea5ba 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -7,6 +7,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 )
@@ -83,11 +85,10 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -127,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type MLP struct {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index a40614af2c..70d7797e96 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -7,6 +7,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -73,7 +75,6 @@ type TextSelfAttention struct {
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
 
 	ropeBase := opts.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
@@ -83,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -94,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -112,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.TextConfig.ropeGlobalBase
 	}
 
-	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type TextMLP struct {
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 9b85f1c703..7b475512bf 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -8,14 +8,16 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 )
 
 type Options struct {
-	hiddenSize, numHeads, numKVHeads, headDim int
-	eps, ropeBase, ropeScale                  float32
-	ropeDim                                   uint32
+	hiddenSize, numHeads, numKVHeads int
+	headDim, ropeDim                 int
+	eps, ropeBase, ropeScale         float32
 }
 
 type Model struct {
@@ -53,10 +55,10 @@ func New(c fs.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim:    int(c.Uint("attention.key_length")),
+			ropeDim:    int(c.Uint("rope.dimension_count")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
 		},
 	}
 
@@ -76,15 +78,14 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
-	ropeType := uint32(0)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
 }
 
 type MLP struct {
diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go
index d98587bd0b..829c805c41 100644
--- a/model/models/llama4/model_text.go
+++ b/model/models/llama4/model_text.go
@@ -8,6 +8,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -31,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
-		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
+		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 	}
 
 	if opts.useQKNorm {
@@ -250,5 +252,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
 }
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index 57e2a40a39..19c36f9fe4 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -8,13 +8,14 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
 	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
-	hiddenSize, numHeads, numKVHeads, headDim int
-	eps, ropeBase, ropeScale                  float32
-	ropeDim                                   uint32
+	hiddenSize, numHeads, numKVHeads int
+	headDim, ropeDim                 int
+	eps, ropeBase, ropeScale         float32
 }
 
 type TextModel struct {
@@ -35,16 +36,15 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(0)
 	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {
@@ -129,10 +129,10 @@ func newTextModel(c fs.Config) *TextModel {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim:    int(c.Uint("attention.key_length")),
+			ropeDim:    int(c.Uint("rope.dimension_count")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
 		},
 	}
 }
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 9bd414afc5..47a518ceda 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -8,6 +8,8 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 )
 
 type TextSelfAttention struct {
@@ -21,15 +23,14 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -44,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
 	}
 
 	return key, nil
@@ -199,8 +200,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
+	ropeDim                          int
 	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
 
 	crossAttentionLayers []int32
 }
@@ -240,10 +241,10 @@ func newTextModel(c fs.Config) *TextModel {
 			hiddenSize:           int(c.Uint("embedding_length")),
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
+			ropeDim:              int(c.Uint("rope.dimension_count")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:             c.Float("rope.freq_base"),
 			ropeScale:            c.Float("rope.freq_scale", 1),
-			ropeDim:              c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
 		},
 	}
 }
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index 800fd961d6..4b6bc16661 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -7,13 +7,15 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/fast"
+	"github.com/ollama/ollama/ml/nn/rope"
 	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
-	ctxLen, hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale                 float32
-	ropeDim, defaultContextLen               uint32
+	hiddenSize, numHeads, numKVHeads int
+	ropeDim, originalContextLength   int
+	eps, ropeBase, ropeScale         float32
 }
 
 type TextModel struct {
@@ -29,15 +31,14 @@ func NewTextModel(c fs.Config) *TextModel {
 	m := TextModel{
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
-			ctxLen:            int(c.Uint("context_length")),
-			hiddenSize:        int(c.Uint("embedding_length")),
-			numHeads:          int(c.Uint("attention.head_count")),
-			numKVHeads:        int(c.Uint("attention.head_count_kv")),
-			eps:               c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:          c.Float("rope.freq_base"),
-			ropeScale:         c.Float("rope.freq_scale", 1),
-			ropeDim:           c.Uint("rope.dimension_count", 128),
-			defaultContextLen: c.Uint("context_length", 128000),
+			hiddenSize:            int(c.Uint("embedding_length")),
+			numHeads:              int(c.Uint("attention.head_count")),
+			numKVHeads:            int(c.Uint("attention.head_count_kv")),
+			ropeDim:               int(c.Uint("rope.dimension_count", 128)),
+			originalContextLength: int(c.Uint("context_length", 128000)),
+			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:              c.Float("rope.freq_base"),
+			ropeScale:             c.Float("rope.freq_scale", 1),
 		},
 	}
 
@@ -59,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -77,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
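
Usage sketch (illustration only, not part of the patch): the Go snippet below mirrors the converted call sites above. The package name example, the applyRoPE wrapper, and its parameters are placeholders for a model's own configuration; only fast.RoPE, rope.WithTypeNeoX, rope.WithOriginalContextLength, and rope.WithFactors are introduced by this commit.

package example // hypothetical wrapper, for illustration only

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
)

// applyRoPE shows the new calling convention: rotary embeddings go through
// fast.RoPE, with the RoPE type, original context length, and frequency
// factors passed as functional options instead of positional arguments.
func applyRoPE(ctx ml.Context, t, positions, factors ml.Tensor, ropeDim int, ropeBase, ropeScale float32) ml.Tensor {
	return fast.RoPE(ctx, t, positions, ropeDim, ropeBase, ropeScale,
		rope.WithTypeNeoX(),                    // sets Type = 2; omit for the default type 0 used by llama and mistral3
		rope.WithOriginalContextLength(131072), // same value as the backend default; qwen25vl passes its configured context_length
		rope.WithFactors(factors),              // no-op when factors is nil, since WithFactors guards against nil
	)
}

This replaces the positional ropeFactors/ropeType arguments removed from Tensor.RoPE and the ml.RopeOptions/WithContextLen helpers deleted from ml/backend.go.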