diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index ec038a287f..3a89afe72c 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac cache.(*kvcache.WrapperCache).SetLayerType(layerType) - // inputPerLayer = inputsPerLayer[:, i, :] - inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx) + // inputPerLayer = inputsPerLayer[:, i, :].squeeze(1) + inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)) hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions) } // hiddenStates = hiddenStates[:, :, 0] - hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1)) + hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1) targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx) targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1) // hiddenState = hiddenStates[:, :, 1:] - hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1) + hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1) altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState) altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx))) @@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps) // inactive := predictions[:, :, 1:] - inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1) + inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1) active = inactive.Add(ctx, active) - predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1)) + predictions0 := predictions.Slice(ctx, 2, 0, 1, 1) return predictions0.Concat(ctx, active, 2) } @@ -319,7 +319,7 @@ type TextOptions struct { func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor { // t[:, :, o.altupActiveIndex] - return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1)) + return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1) } func (o *TextOptions) headDim() int {