diff --git a/kvcache/causal.go b/kvcache/causal.go
index 1d4daf809..b2e7b3ab0 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -90,6 +90,14 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 		c.config.CachePadding = 1
 	}
 
+	if c.config.MaskBatchPadding == 0 {
+		c.config.MaskBatchPadding = 1
+	}
+
+	if c.config.MaskDType == ml.DTypeOther {
+		c.config.MaskDType = ml.DTypeF32
+	}
+
 	c.DType = dtype
 	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
 	c.cells = make([]cacheCell, c.Capacity)
@@ -192,13 +200,14 @@ func roundUp(length, pad int) int {
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
-	// TODO(jessegross): This does not do mask padding, which is required for flash attention
-	// Align and pad the cache range as required by the backend
+	// Align and pad the two dimensions as required by the backend
+	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
+
 	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
-	mask := make([]float32, c.curBatchSize*length)
+	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
@@ -209,7 +218,24 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 		}
 	}
 
-	return ctx.FromFloatSlice(mask, length, c.curBatchSize)
+	// Mask out any padding tokens we added. For padding that we added to the cache history, this
+	// has already been masked out because the sequence doesn't match.
+	for i := c.curBatchSize * length; i < len(mask); i++ {
+		mask[i] = float32(math.Inf(-1))
+	}
+
+	maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
+	if err != nil {
+		return nil, err
+	}
+
+	if c.config.MaskDType != ml.DTypeF32 {
+		out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
+		ctx.Forward(maskTensor.Copy(ctx, out))
+		maskTensor = out
+	}
+
+	return maskTensor, nil
 }
 
 func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
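
A minimal sketch of the padding arithmetic buildMask now performs, using a stand-in roundUp equivalent to the helper in kvcache/causal.go and plain constants in place of the cache's internal state; real rows would additionally be masked by sequence and causality, only the padded tail rows are shown here:

package main

import (
	"fmt"
	"math"
)

// roundUp mirrors the helper used by the causal cache.
func roundUp(length, pad int) int { return (length + pad - 1) / pad * pad }

func main() {
	const (
		curBatchSize     = 3  // real tokens in this batch
		maskBatchPadding = 8  // stand-in for config.MaskBatchPadding
		length           = 16 // cache cells after CachePadding alignment
	)

	// The batch dimension is rounded up so the mask shape meets the backend's requirement.
	batchSize := roundUp(curBatchSize, maskBatchPadding)
	mask := make([]float32, batchSize*length)

	// Rows beyond the real batch correspond to no token at all, so every entry is -Inf.
	for i := curBatchSize * length; i < len(mask); i++ {
		mask[i] = float32(math.Inf(-1))
	}

	fmt.Printf("mask is %d x %d; rows %d..%d are padding\n", batchSize, length, curBatchSize, batchSize-1)
}
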
diff --git a/ml/backend.go b/ml/backend.go
index de2725c02..83b7a8c9c 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -46,6 +46,14 @@ type CacheConfig struct {
 	// and return the permuted version via Get. This uses the cache copy operation
 	// to avoid a Contiguous call on the permuted tensor.
 	PermutedV bool
+
+	// MaskDType specifies the data type for generating the mask. If unset it will
+	// default to DTypeF32.
+	MaskDType DType
+
+	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
+	// Any position that does not correspond to an actual token will be filled with -Inf.
+	MaskBatchPadding int
 }
 
 // BackendParams controls how the backend loads and executes models
@@ -61,6 +69,9 @@ type BackendParams struct {
 
 	// TensorSplit is the fraction of the model to offload to each GPU
 	TensorSplit []float32
+
+	// FlashAttention indicates that we should use a fused flash attention kernel
+	FlashAttention bool
 }
 
 var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 2c7e856cc..f4948fcad 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
 })
 
 type Backend struct {
+	flashAttention bool
+
 	meta       *fs.GGML
 	cpus, gpus []Context
 	tensors    map[string]*Context
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	}
 
 	return &Backend{
-		meta: meta,
-		cpus: cpus,
-		gpus: gpus,
+		flashAttention: params.FlashAttention,
+		meta:           meta,
+		cpus:           cpus,
+		gpus:           gpus,
 		sched: C.ggml_backend_sched_new(
 			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -248,7 +251,11 @@ func (b *Backend) NewContext() ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+	if b.flashAttention {
+		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
+	} else {
+		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+	}
 }
 
 type Context struct {
@@ -705,14 +712,22 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	kq := key.MulmatFullPrec(ctx, query)
-	kq = &Tensor{
-		b: t.b,
-		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-	}
+	if t.b.flashAttention {
+		value = value.Permute(ctx, 0, 2, 1, 3)
 
-	kqv := value.Mulmat(ctx, kq)
-	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
+		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+		return &Tensor{b: t.b, t: kqv}
+	} else {
+		kq := key.MulmatFullPrec(ctx, query)
+		kq = &Tensor{
+			b: t.b,
+			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+		}
+
+		kqv := value.Mulmat(ctx, kq)
+		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	}
 }
 
 func (b *Backend) SystemInfo() string {
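
For reference, ggml_flash_attn_ext fuses the same computation the unfused branch spells out: softmax(scale * QK^T + mask) multiplied by V, where a -Inf mask entry removes a cell from the softmax entirely. A rough single-head sketch of that math on plain slices (hypothetical shapes, not the ggml tensor API):

package main

import (
	"fmt"
	"math"
)

// naiveAttention computes softmax(scale*Q·Kᵀ + mask)·V for one head.
// q is [nTokens][dHead], k and v are [nCells][dHead], mask is [nTokens][nCells].
func naiveAttention(q, k, v, mask [][]float64, scale float64) [][]float64 {
	out := make([][]float64, len(q))
	for i := range q {
		// Scores for query i against every cached key, plus the mask (-Inf hides a cell).
		scores := make([]float64, len(k))
		for j := range k {
			for d := range q[i] {
				scores[j] += q[i][d] * k[j][d]
			}
			scores[j] = scores[j]*scale + mask[i][j]
		}

		// Numerically stable softmax over the scores.
		maxScore := math.Inf(-1)
		for _, s := range scores {
			maxScore = math.Max(maxScore, s)
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxScore)
			sum += scores[j]
		}

		// Weighted sum of the values.
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			w := scores[j] / sum
			for d := range v[j] {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	mask := [][]float64{{0, math.Inf(-1)}}          // the second cell belongs to another sequence
	fmt.Println(naiveAttention(q, k, v, mask, 1.0)) // ~[[1 2]]: the masked cell contributes nothing
}
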
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index db9b271e5..5705931ad 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -818,7 +818,7 @@ func Execute(args []string) error {
 	batchSize := fs.Int("batch-size", 512, "Batch size")
 	numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
 	mainGPU := fs.Int("main-gpu", 0, "Main GPU")
-	_ = fs.Bool("flash-attn", false, "Enable flash attention")
+	flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
 	kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
 	kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
 	port := fs.Int("port", 8080, "Port to expose the server on")
@@ -863,7 +863,6 @@ func Execute(args []string) error {
 	}
 
 	// TODO(jessegross): Parameters that need to be implemented:
-	// flash-attn
 	// no-mmap
 	// mlock
 
@@ -878,10 +877,11 @@ func Execute(args []string) error {
 	}
 
 	params := ml.BackendParams{
-		NumThreads:   *threads,
-		NumGPULayers: *numGPULayers,
-		MainGPU:      *mainGPU,
-		TensorSplit:  tensorSplitFloats,
+		NumThreads:     *threads,
+		NumGPULayers:   *numGPULayers,
+		MainGPU:        *mainGPU,
+		TensorSplit:    tensorSplitFloats,
+		FlashAttention: *flashAttention,
 	}
 
 	server.ready.Add(1)
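
Finally, a condensed sketch of how the new flag flows end to end, using trimmed stand-in types rather than the real ml.BackendParams/ml.CacheConfig: the runner parses -flash-attn, the backend params carry it, and the backend answers with a different cache config when it is set (the numbers here are illustrative; the real mask padding comes from C.GGML_KQ_MASK_PAD and the mask dtype from ml.DTypeF16):

package main

import (
	"flag"
	"fmt"
)

// Trimmed stand-ins for ml.BackendParams and ml.CacheConfig.
type backendParams struct{ FlashAttention bool }

type cacheConfig struct {
	CachePadding     int
	MaskBatchPadding int
	MaskDType        string
	PermutedV        bool
}

// cacheConfigFor mirrors the branch the ggml backend takes in CacheConfig().
func cacheConfigFor(p backendParams) cacheConfig {
	if p.FlashAttention {
		// 64 is a placeholder for C.GGML_KQ_MASK_PAD.
		return cacheConfig{CachePadding: 256, MaskBatchPadding: 64, MaskDType: "f16"}
	}
	// The cache defaults an unset MaskBatchPadding to 1 and an unset MaskDType to f32.
	return cacheConfig{CachePadding: 32, MaskBatchPadding: 1, MaskDType: "f32", PermutedV: true}
}

func main() {
	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
	flag.Parse()

	params := backendParams{FlashAttention: *flashAttention}
	fmt.Printf("%+v\n", cacheConfigFor(params))
}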