mirror of
https://github.com/ollama/ollama.git
synced 2025-08-26 19:31:28 +02:00
feat: qwen3 dense and sparse models (#10708)
* feat: qwen3 dense * feat: qwen3moe * fix llama4 moe
This commit is contained in:
@@ -128,6 +128,8 @@ type Tensor interface {
|
|||||||
Neg(ctx Context) Tensor
|
Neg(ctx Context) Tensor
|
||||||
Add(ctx Context, t2 Tensor) Tensor
|
Add(ctx Context, t2 Tensor) Tensor
|
||||||
Mul(ctx Context, t2 Tensor) Tensor
|
Mul(ctx Context, t2 Tensor) Tensor
|
||||||
|
Div(ctx Context, t2 Tensor) Tensor
|
||||||
|
|
||||||
Mulmat(ctx Context, t2 Tensor) Tensor
|
Mulmat(ctx Context, t2 Tensor) Tensor
|
||||||
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||||
MulmatID(ctx Context, t2, ids Tensor) Tensor
|
MulmatID(ctx Context, t2, ids Tensor) Tensor
|
||||||
@@ -136,6 +138,7 @@ type Tensor interface {
|
|||||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||||
Scale(ctx Context, s float64) Tensor
|
Scale(ctx Context, s float64) Tensor
|
||||||
|
SumRows(ctx Context) Tensor
|
||||||
|
|
||||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||||
|
@@ -887,6 +887,13 @@ func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
@@ -1004,6 +1011,13 @@ func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
|
@@ -82,7 +82,7 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
|||||||
|
|
||||||
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
|
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
|
||||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||||
nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
|
nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nextStates
|
return nextStates
|
||||||
|
@@ -8,4 +8,5 @@ import (
|
|||||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||||
_ "github.com/ollama/ollama/model/models/mllama"
|
_ "github.com/ollama/ollama/model/models/mllama"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||||
|
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||||
)
|
)
|
||||||
|
239
model/models/qwen3/model.go
Normal file
239
model/models/qwen3/model.go
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
package qwen3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/fast"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
hiddenSize, numHeads, numKVHeads int
|
||||||
|
eps float32
|
||||||
|
ropeBase, ropeScale float32
|
||||||
|
|
||||||
|
keyLength, valueLength int
|
||||||
|
|
||||||
|
numExperts, numExpertsUsed int
|
||||||
|
normTopKProb bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) headDim() int {
|
||||||
|
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Attention struct {
|
||||||
|
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
|
batchSize := hiddenStates.Dim(1)
|
||||||
|
|
||||||
|
query := sa.Query.Forward(ctx, hiddenStates)
|
||||||
|
key := sa.Key.Forward(ctx, hiddenStates)
|
||||||
|
value := sa.Value.Forward(ctx, hiddenStates)
|
||||||
|
|
||||||
|
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||||
|
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||||
|
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||||
|
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||||
|
|
||||||
|
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||||
|
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||||
|
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||||
|
return sa.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP interface {
|
||||||
|
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
type sparse struct {
|
||||||
|
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||||
|
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
|
||||||
|
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
|
||||||
|
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
|
||||||
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
|
||||||
|
routerLogits := mlp.Router.Forward(ctx, hiddenStates)
|
||||||
|
|
||||||
|
routingWeights := routerLogits.Softmax(ctx)
|
||||||
|
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||||
|
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
|
||||||
|
if opts.normTopKProb {
|
||||||
|
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||||
|
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||||
|
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||||
|
|
||||||
|
upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||||
|
|
||||||
|
hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||||
|
hiddenStates = hiddenStates.SILU(ctx)
|
||||||
|
hiddenStates = hiddenStates.Mul(ctx, upStates)
|
||||||
|
|
||||||
|
experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||||
|
experts = experts.Mul(ctx, routingWeights)
|
||||||
|
|
||||||
|
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||||
|
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||||
|
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextStates
|
||||||
|
}
|
||||||
|
|
||||||
|
type dense struct {
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||||
|
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenStates)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Layer struct {
|
||||||
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||||
|
*Attention
|
||||||
|
|
||||||
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||||
|
MLP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
|
residual := hiddenStates
|
||||||
|
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||||
|
|
||||||
|
if outputs != nil {
|
||||||
|
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||||
|
residual = residual.Rows(ctx, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
|
||||||
|
residual = hiddenStates
|
||||||
|
hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
hiddenStates = d.MLP.Forward(ctx, hiddenStates, opts)
|
||||||
|
return hiddenStates.Add(ctx, residual)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.BytePairEncoding
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
|
Layers []Layer `gguf:"blk"`
|
||||||
|
|
||||||
|
*Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements model.Model.
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
|
var outputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.Model = (*Model)(nil)
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
layers := make([]Layer, c.Uint("block_count"))
|
||||||
|
for i := range layers {
|
||||||
|
if c.String("general.architecture") == "qwen3moe" {
|
||||||
|
layers[i].MLP = &sparse{}
|
||||||
|
} else {
|
||||||
|
layers[i].MLP = &dense{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Layers: layers,
|
||||||
|
Options: &Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
keyLength: int(c.Uint("attention.key_length")),
|
||||||
|
valueLength: int(c.Uint("attention.value_length")),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
|
ropeScale: c.Float("rope.freq_scale", 1),
|
||||||
|
numExperts: int(c.Uint("expert_count")),
|
||||||
|
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||||
|
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("qwen3", New)
|
||||||
|
model.Register("qwen3moe", New)
|
||||||
|
}
|
Reference in New Issue
Block a user