From c890011322fbdd325ef9f16e425fe1f5213a24fe Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 21 May 2025 10:21:24 -0700 Subject: [PATCH] feat: port qwen2 model (#10782) --- model/models/llama/model.go | 47 +++++----- model/models/models.go | 1 + model/models/qwen2/model.go | 170 ++++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 model/models/qwen2/model.go diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 7b475512b..507f1ebc2 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -75,30 +75,31 @@ type SelfAttention struct { RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` } -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) - q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query := sa.Query.Forward(ctx, hiddenState) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key := sa.Key.Forward(ctx, hiddenState) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - v := sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + value := sa.Value.Forward(ctx, hiddenState) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - return sa.Output.Forward(ctx, kqv) + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, attention) } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { @@ -119,11 +120,11 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts) // In the final layer (outputs != nil), optimize by pruning to just the token positions // we need logits for. @@ -146,22 +147,20 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return nil, err } - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } - hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) for i, layer := range m.Layers { m.Cache.SetLayer(i) - var lastLayerOutputs ml.Tensor + var outputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } } - hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/models.go b/model/models/models.go index fd935f30c..5471ce89a 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -7,6 +7,7 @@ import ( _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" + _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" ) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go new file mode 100644 index 000000000..3c3d81aa5 --- /dev/null +++ b/model/models/qwen2/model.go @@ -0,0 +1,170 @@ +package qwen2 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) + + query := attn.Query.Forward(ctx, hiddenStates) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + + key := attn.Key.Forward(ctx, hiddenStates) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + value := attn.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + + return attn.Output.Forward(ctx, attention) +} + +type MLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type DecoderLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.MLP.Forward(ctx, hiddenStates) + return hiddenStates.Add(ctx, residual) +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []DecoderLayer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +// Forward implements model.Model. +func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + hiddenStates = m.Output.Forward(ctx, hiddenStates) + return hiddenStates, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil +} + +func New(c fs.Config) (model.Model, error) { + m := Model{ + Layers: make([]DecoderLayer, c.Uint("block_count")), + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + eps: c.Float("attention.layer_norm_rms_epsilon"), + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func init() { + model.Register("qwen2", New) +}