From 5c5535c0648fb12b174246eb2524e862ae2d2d5b Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 18 Feb 2025 17:16:57 -0800
Subject: [PATCH] models: Prune unused outputs earlier in the forward pass

Currently Rows is called as the last step in a model computation to
get the values for the output tokens. However, if we move it earlier
in the process then we can trim out computations that never get used.
This is similar to how models are defined in llama.cpp.

Changing the model definition in this way improves token generation
performance by approximately 8%.
---
 model/models/llama/model.go       | 36 ++++++++++++++++++++-----------
 model/models/mllama/model.go      |  6 ++----
 model/models/mllama/model_text.go | 27 +++++++++++++++++------
 3 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index b2c5c2c7b..e90631fb8 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -120,11 +120,19 @@ type Layer struct {
 	MLP *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-
-	for i, layer := range m.Layers {
-		m.Cache.SetLayer(i)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
-	}
-
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	hiddenState = m.Output.Forward(ctx, hiddenState)
-
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+
+	for i, layer := range m.Layers {
+		m.Cache.SetLayer(i)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	return m.Output.Forward(ctx, hiddenState), nil
 }
 
 func init() {
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index a1460d940..f5521ce5c 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	// TODO: attention mask, cross attention mask
-	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
-
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	// TODO: attention mask, cross attention mask
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 1e48086a3..8ad804cf9 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
 }
 
-func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 }
 
 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
 }
 
 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
 		layerType := selfAttentionLayer
 		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 		cache.SetLayerType(layerType)
 
 		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
+			var lastLayerOutputs ml.Tensor
+			if i == len(d.Layers)-1 {
+				lastLayerOutputs = outputs
+			}
+
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
 		}
 	}
 
@@ -205,9 +218,9 @@ type TextModel struct {
 	*TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }
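
As an illustration of why pruning earlier helps, here is a minimal standalone Go
sketch of the idea. It does not use ollama's ml.Tensor API: the matrix type and
the gatherRows and matmul helpers below are toy stand-ins invented for this note,
and only the idea taken from the patch (select the needed rows before the final
projection, as Rows does) is real. Nothing here should be read as the actual
implementation.

package main

import "fmt"

// matrix is a toy row-major matrix (rows x cols), not ollama's ml.Tensor.
type matrix struct {
	rows, cols int
	data       []float32
}

func newMatrix(rows, cols int) matrix {
	return matrix{rows: rows, cols: cols, data: make([]float32, rows*cols)}
}

// gatherRows plays the role of Rows in the patch: keep only the requested rows.
func gatherRows(m matrix, rows []int) matrix {
	out := newMatrix(len(rows), m.cols)
	for i, r := range rows {
		copy(out.data[i*m.cols:(i+1)*m.cols], m.data[r*m.cols:(r+1)*m.cols])
	}
	return out
}

// matmul computes an (a.rows x b.cols) product; its cost scales with a.rows.
func matmul(a, b matrix) matrix {
	out := newMatrix(a.rows, b.cols)
	for i := 0; i < a.rows; i++ {
		for k := 0; k < a.cols; k++ {
			v := a.data[i*a.cols+k]
			for j := 0; j < b.cols; j++ {
				out.data[i*b.cols+j] += v * b.data[k*b.cols+j]
			}
		}
	}
	return out
}

func main() {
	const (
		seqLen    = 8  // token positions in the batch
		hidden    = 4  // hidden dimension
		vocabSize = 16 // output vocabulary size
	)

	hiddenState := newMatrix(seqLen, hidden)      // stand-in for the last layer's activations
	outputWeight := newMatrix(hidden, vocabSize)  // stand-in for the output projection

	// Before the patch: project every position, then select the rows we need.
	//   logits := gatherRows(matmul(hiddenState, outputWeight), outputs)

	// After the patch: select rows first, so the projection only runs over the
	// positions whose logits will actually be used.
	outputs := []int{seqLen - 1} // during generation this is usually just the last token
	pruned := gatherRows(hiddenState, outputs)
	logits := matmul(pruned, outputWeight)

	fmt.Printf("projected %d of %d positions -> logits %dx%d\n",
		pruned.rows, seqLen, logits.rows, logits.cols)
}

In the patch itself the same effect reaches further back: because Rows is applied
inside the final layer, the last layer's MLP, the output norm, and the output
projection all operate on the pruned rows rather than on the full batch.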