fix drift from main

Jesse Gross 2025-03-07 09:30:10 -08:00 committed by Michael Yang
parent 4b037a97dc
commit 4346c2409d
5 changed files with 49 additions and 22 deletions

View File

@@ -441,6 +441,10 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+panic("not implemented")
+}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
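For context, testTensor has to satisfy the full ml.Tensor interface for the cache tests to compile, so any method added to the interface on main needs a matching panic stub here. A minimal sketch of that pattern, using a made-up two-method interface rather than the real ml.Tensor:

```go
package main

// Tensor is a stand-in for ml.Tensor, reduced to two methods.
type Tensor interface {
	Scale(s float64) Tensor
	AvgPool1D(k, s, p int) Tensor
}

// testTensor only needs to type-check in tests, not compute anything,
// so every method simply panics.
type testTensor struct{}

func (t *testTensor) Scale(s float64) Tensor       { panic("not implemented") }
func (t *testTensor) AvgPool1D(k, s, p int) Tensor { panic("not implemented") }

// Compile-time assertion that *testTensor implements Tensor; adding a
// method to the interface without a stub breaks the build right here.
var _ Tensor = (*testTensor)(nil)

func main() {}
```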

View File

@@ -64,6 +64,7 @@ func New(c ml.Config) (model.Model, error) {
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+m.Cache.SetConfig(ml.CacheConfig{})
return &m, nil
}
@@ -84,7 +85,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
-q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads)))
+q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
@@ -99,8 +100,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
-q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+q = q.Permute(ctx, 0, 2, 1, 3)
+k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
@@ -144,12 +145,20 @@ type Layer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+// In the final layer (outputs != nil), optimize by pruning to just the token positions
+// we need logits for.
+if outputs != nil {
+hiddenState = hiddenState.Rows(ctx, outputs)
+residual = residual.Rows(ctx, outputs)
+}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -170,6 +179,11 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
+outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+if err != nil {
+return nil, err
+}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -182,7 +196,13 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
-hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+var lastLayerOutputs ml.Tensor
+if i == len(m.Layers)-1 {
+lastLayerOutputs = outputs
+}
+hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -192,12 +212,6 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-if err != nil {
-return nil, err
-}
return hiddenState.Rows(ctx, outputs), nil
}
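The optimization threaded through the gemma2 changes above has two parts: Forward now builds the outputs tensor from opts.Outputs up front and hands it only to the last layer (lastLayerOutputs), and Layer.Forward then uses Rows to keep just those hidden-state rows, so the output projection and softcap run only over positions that actually need logits. A toy sketch of the row-gathering step, using plain Go slices rather than the real ml.Tensor API:

```go
package main

import "fmt"

// rows mimics Tensor.Rows for this sketch: gather the rows of m listed in idx.
func rows(m [][]float32, idx []int32) [][]float32 {
	out := make([][]float32, 0, len(idx))
	for _, i := range idx {
		out = append(out, m[i])
	}
	return out
}

func main() {
	// Hidden states for a 4-token batch with a 2-dim model (made-up numbers).
	hidden := [][]float32{{1, 2}, {3, 4}, {5, 6}, {7, 8}}

	// The caller only wants logits for the last position, so the final layer
	// can drop the other rows before the expensive vocab projection.
	outputs := []int32{3}
	fmt.Println(rows(hidden, outputs)) // [[7 8]]
}
```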

View File

@@ -66,9 +66,6 @@ func newTextModel(c ml.Config) *TextModel {
},
}
slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m
}
@@ -145,12 +142,20 @@ type TextLayer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+// In the final layer (outputs != nil), optimize by pruning to just the token positions
+// we need logits for.
+if outputs != nil {
+hiddenState = hiddenState.Rows(ctx, outputs)
+residual = residual.Rows(ctx, outputs)
+}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -181,7 +186,13 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
-hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions)
+var lastLayerOutputs ml.Tensor
+if i == len(m.Layers)-1 {
+lastLayerOutputs = outputs
+}
+hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -190,7 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
-hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
-return hiddenState.Rows(ctx, outputs)
+return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
}
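The tail of TextModel.Forward above applies the final-logit softcap in three steps: scale by 1/finalLogitSoftcap, Tanh, then scale back by finalLogitSoftcap, i.e. softcap(x) = c·tanh(x/c), which squashes logits smoothly into (-c, c). A standalone sketch (the cap value below is illustrative, not Gemma's configured one):

```go
package main

import (
	"fmt"
	"math"
)

// softcap mirrors the scale -> tanh -> scale sequence: c * tanh(x / c).
func softcap(x, c float64) float64 {
	return c * math.Tanh(x/c)
}

func main() {
	// With a cap of 30, small logits pass through nearly unchanged while
	// large ones saturate toward the cap.
	for _, x := range []float64{1, 10, 100} {
		fmt.Printf("softcap(%v) = %.3f\n", x, softcap(x, 30))
	}
}
```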

View File

@@ -53,7 +53,7 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
}
type VisionEncoderLayer struct {
-LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
+LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`

View File

@@ -73,7 +73,7 @@ func TestSentencePieceEncode(t *testing.T) {
}
for _, want := range cases {
-ids, err := tokenizer.Encode(want)
+ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
@@ -98,7 +98,7 @@ func TestSentencePieceEncode(t *testing.T) {
}
for _, want := range cases {
-ids, err := tokenizer.Encode(want.token)
+ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
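The Encode calls in the updated tests take a second boolean, which presumably toggles whether special tokens (such as BOS) are added; true requests them. A rough sketch of that call shape against a stub tokenizer, since the real one needs a model file — the stub and its BOS id are invented for illustration:

```go
package main

import "fmt"

// stubTokenizer mimics the Encode(s, addSpecial) shape used by the tests;
// it is not ollama's SentencePiece implementation.
type stubTokenizer struct{}

func (stubTokenizer) Encode(s string, addSpecial bool) ([]int32, error) {
	var ids []int32
	if addSpecial {
		ids = append(ids, 1) // pretend 1 is the BOS id
	}
	for _, r := range s {
		ids = append(ids, int32(r)) // toy per-rune "tokens"
	}
	return ids, nil
}

func main() {
	ids, err := stubTokenizer{}.Encode("hi", true)
	if err != nil {
		panic(err)
	}
	fmt.Println(ids) // [1 104 105]
}
```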