From ea79003180205680000bacf97466fc9d78d71f5e Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 27 May 2025 13:33:57 -0700 Subject: [PATCH] kvcache: Skip computing causal mask for worst case graph reservation Computing an attention mask for a large context and max batch is expensive - over 100ms. Models like Gemma3 that have multiple types of caches and custom attention masks need to do this 4 times, so this adds approximately 500ms to startup time when using 128k context When we are reserving the worst case graph, we don't need the mask, only its shape, so we can skip this. --- kvcache/causal.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index f6bacaaf8..b594d0b41 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -30,6 +30,11 @@ type Causal struct { // ** current forward pass ** + // curReserve indicates that this forward pass is only for + // memory reservation and we should not update our metadata + // based on it. + curReserve bool + // the active layer for Get and Put curLayer int @@ -159,12 +164,13 @@ func (c *Causal) Close() { } func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + c.curReserve = reserve c.curBatchSize = len(batch.Positions) c.curSequences = batch.Sequences c.curPositions = batch.Positions c.opts.Except = nil - if !reserve { + if !c.curReserve { c.updateSlidingWindow() var err error @@ -304,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 length := c.curCellRange.max - c.curCellRange.min + 1 + + if c.curReserve { + return ctx.Input().Empty(c.config.MaskDType, length, batchSize) + } + mask := make([]float32, batchSize*length) for i := range c.curBatchSize {