diff --git a/kvcache/causal.go b/kvcache/causal.go index f6bacaaf8..b594d0b41 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -30,6 +30,11 @@ type Causal struct { // ** current forward pass ** + // curReserve indicates that this forward pass is only for + // memory reservation and we should not update our metadata + // based on it. + curReserve bool + // the active layer for Get and Put curLayer int @@ -159,12 +164,13 @@ func (c *Causal) Close() { } func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + c.curReserve = reserve c.curBatchSize = len(batch.Positions) c.curSequences = batch.Sequences c.curPositions = batch.Positions c.opts.Except = nil - if !reserve { + if !c.curReserve { c.updateSlidingWindow() var err error @@ -304,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 length := c.curCellRange.max - c.curCellRange.min + 1 + + if c.curReserve { + return ctx.Input().Empty(c.config.MaskDType, length, batchSize) + } + mask := make([]float32, batchSize*length) for i := range c.curBatchSize {