From 2e77aa1ae70372388bd4b08b9957e5198d566a22 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 11 Jun 2025 12:10:15 -0700 Subject: [PATCH] use nn.Linear in place of ml.Tensor (#11049) while nn.Linear.Forward isn't applicable for sparse MLP, it's still a nice container for the tensors --- model/models/llama4/model_text.go | 12 ++++++------ model/models/qwen3/model.go | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 27935f4012..045ab403f2 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp } type TextExperts struct { - Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` - Up ml.Tensor `gguf:"ffn_up_exps.weight"` - Down ml.Tensor `gguf:"ffn_down_exps.weight"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` } func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { @@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Mul(ctx, scores) - upStates := e.Up.MulmatID(ctx, hiddenStates, experts) - gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts) - downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) + gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) + downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 1930da7e20..7a83e0d04a 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -66,9 +66,9 @@ type MLP interface { type sparse struct { Router *nn.Linear `gguf:"ffn_gate_inp"` - Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` - Up ml.Tensor `gguf:"ffn_up_exps.weight"` - Down ml.Tensor `gguf:"ffn_down_exps.weight"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` } func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { @@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) - upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts) + upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) - hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) hiddenStates = hiddenStates.SILU(ctx) hiddenStates = hiddenStates.Mul(ctx, upStates) - experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts) + experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))