From 85ccf7354dd8b32862e1a27398780094504c7fd8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 26 Aug 2025 13:34:45 -0700 Subject: [PATCH] gptoss: enable flash attention by default (#11996) --- fs/ggml/ggml.go | 15 ++++++++++++++- llm/memory.go | 10 ++++++---- llm/server.go | 5 +++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index a739e99ba9..dca0187b0a 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -10,6 +10,7 @@ import ( "slices" "strings" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/util/bufioutil" ) @@ -479,7 +480,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { }, nil } -func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { +func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) { context *= uint64(numParallel) embedding := f.KV().EmbeddingLength() @@ -677,7 +678,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri kv[i] *= context } } + partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6 + if useFlashAttention { + // rough estimate of graph size with flash attention on + partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte + } } return @@ -773,6 +779,13 @@ func (f GGML) SupportsFlashAttention() bool { return headCountK != 0 && headCountV != 0 && headCountK == headCountV } +// FlashAttention checks if the model should enable flash attention +func (f GGML) FlashAttention() bool { + return slices.Contains([]string{ + "gptoss", "gpt-oss", + }, f.KV().String("general.architecture")) +} + // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type func kvCacheBytesPerElement(cacheType string) float64 { switch cacheType { diff --git a/llm/memory.go b/llm/memory.go index d8ae5e44ad..ce128eb585 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -195,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin slog.Warn("model missing blk.0 layer size") } - var kvct string - if envconfig.FlashAttention() && + useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) && discover.GetGPUInfo().FlashAttentionSupported() && - f.SupportsFlashAttention() { + f.SupportsFlashAttention() + + var kvct string + if useFlashAttention { requested := strings.ToLower(envconfig.KvCacheType()) if requested != "" && f.SupportsKVCacheType(requested) { kvct = requested } } - kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct) + kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention) if len(kv) > 0 { layerSize += kv[0] diff --git a/llm/server.go b/llm/server.go index b05e9b82da..30cf5c3609 100644 --- a/llm/server.go +++ b/llm/server.go @@ -195,6 +195,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a // This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset // that can handle it. fa := envconfig.FlashAttention() + if f.FlashAttention() { + slog.Info("model wants flash attention") + fa = true + } + if fa && !gpus.FlashAttentionSupported() { slog.Warn("flash attention enabled but not supported by gpu") fa = false