This commit is contained in:
Michael Yang
2025-11-03 12:55:03 -08:00
parent 702d5c71c9
commit bef18cd700

View File

@@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
// inputPerLayer = inputsPerLayer[:, i, :]
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
// inputPerLayer = inputsPerLayer[:, i, :].squeeze(1)
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
}
// hiddenStates = hiddenStates[:, :, 0]
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1)
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
// hiddenState = hiddenStates[:, :, 1:]
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1)
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
@@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
// inactive := predictions[:, :, 1:]
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)
active = inactive.Add(ctx, active)
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
predictions0 := predictions.Slice(ctx, 2, 0, 1, 1)
return predictions0.Concat(ctx, active, 2)
}
@@ -319,7 +319,7 @@ type TextOptions struct {
// altupActive selects the part of t at index o.altupActiveIndex along
// dimension 2, i.e. t[:, :, o.altupActiveIndex] in slice notation.
//
// NOTE(review): the two return statements below look like the removed and
// added lines of a diff rendered without +/- markers. Read as plain Go, only
// the first (View) return executes and the Slice return is unreachable —
// confirm which form is intended before relying on this block.
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
// t[:, :, o.altupActiveIndex]
// View form: offset o.altupActiveIndex*t.Stride(2), followed by Dim(0),
// Stride(1), Dim(1); presumably yields a result without dimension 2 — TODO confirm.
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
// Slice form: dimension 2, range o.altupActiveIndex to o.altupActiveIndex+1;
// the trailing 1 is presumably the step, and the result presumably keeps
// dimension 2 with length 1 — verify against ml.Tensor.Slice.
return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1)
}
func (o *TextOptions) headDim() int {