mirror of
https://github.com/ollama/ollama.git
synced 2025-03-18 05:41:43 +01:00
models: Prune unused outputs earlier in the forward pass
Currently Rows is called as the last step in a model computation to get the values for the output tokens. However, if we move it earlier in the process then we can trim out computations that never get used. This is similar to how models are defined in llama.cpp. Changing the model definition in this way improves token generation performance by approximately 8%.
This commit is contained in:
parent
e5bcc51ae1
commit
5c5535c064
@ -120,11 +120,19 @@ type Layer struct {
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
||||
|
||||
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hiddenState.Rows(ctx, outputs), nil
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
|
||||
|
||||
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hiddenState.Rows(ctx, outputs), nil
|
||||
// TODO: attention mask, cross attention mask
|
||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
|
||||
MLP *TextMLP
|
||||
}
|
||||
|
||||
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
|
||||
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
|
||||
}
|
||||
|
||||
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
|
||||
}
|
||||
|
||||
type TextDecoderLayer interface {
|
||||
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
|
||||
Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
|
||||
}
|
||||
|
||||
type TextDecoder struct {
|
||||
Layers []TextDecoderLayer
|
||||
}
|
||||
|
||||
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
for i, layer := range d.Layers {
|
||||
layerType := selfAttentionLayer
|
||||
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
|
||||
@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
|
||||
cache.SetLayerType(layerType)
|
||||
|
||||
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(d.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,9 +218,9 @@ type TextModel struct {
|
||||
*TextModelOptions
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
|
||||
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
|
||||
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user