mirror of
https://github.com/ollama/ollama.git
synced 2025-07-28 18:03:05 +02:00
kvcache: Skip computing causal mask for worst case graph reservation
Computing an attention mask for a large context and max batch is expensive - over 100ms. Models like Gemma3 that have multiple types of caches and custom attention masks need to do this 4 times, so this adds approximately 500ms to startup time when using 128k context When we are reserving the worst case graph, we don't need the mask, only its shape, so we can skip this.
This commit is contained in:
@@ -30,6 +30,11 @@ type Causal struct {
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
@@ -159,12 +164,13 @@ func (c *Causal) Close() {
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.curReserve = reserve
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
if !reserve {
|
||||
if !c.curReserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
@@ -304,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
if c.curReserve {
|
||||
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
|
||||
}
|
||||
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
|
Reference in New Issue
Block a user