use nn.Linear in place of ml.Tensor (#11049)

while nn.Linear.Forward isn't applicable for sparse MLP, it's still
a nice container for the tensors
This commit is contained in:
Michael Yang
2025-06-11 12:10:15 -07:00
committed by GitHub
parent deaabe292d
commit 2e77aa1ae7
2 changed files with 12 additions and 12 deletions

View File

@@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp
}
type TextExperts struct {
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {

View File

@@ -66,9 +66,9 @@ type MLP interface {
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.SILU(ctx)
hiddenStates = hiddenStates.Mul(ctx, upStates)
experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))