Mirror of https://github.com/ollama/ollama.git · synced 2025-08-29 21:41:40 +02:00
gptoss: enable flash attention by default (#11996)
fs/ggml/ggml.go

@@ -10,6 +10,7 @@ import (
 	"slices"
 	"strings"
 
+	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/util/bufioutil"
 )
 
@@ -479,7 +480,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
 	}, nil
 }
 
-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
 	context *= uint64(numParallel)
 
 	embedding := f.KV().EmbeddingLength()
@@ -677,7 +678,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 				kv[i] *= context
 			}
 		}
 
 		partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
+		if useFlashAttention {
+			// rough estimate of graph size with flash attention on
+			partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
+		}
 	}
 
 	return
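When flash attention is on, the new branch swaps the head-count heuristic for a flat size-in-MiB estimate. A minimal standalone sketch of that arithmetic, assuming only that format.MebiByte is 1024*1024; the helper name here is illustrative, not ollama code:

package main

import "fmt"

const mebiByte = 1024 * 1024 // stand-in for format.MebiByte

// flashAttnGraphSize reproduces the estimate added above:
// (4*numParallel + context>>10 + 110) MiB. Note that inside
// GraphSize, context has already been multiplied by numParallel.
func flashAttnGraphSize(numParallel, context uint64) uint64 {
	return (4*numParallel + context>>10 + 110) * mebiByte
}

func main() {
	// One sequence at an 8192-token context:
	// (4 + 8192>>10 + 110) MiB = (4 + 8 + 110) MiB = 122 MiB
	fmt.Println(flashAttnGraphSize(1, 8192) / mebiByte) // 122
}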
@@ -773,6 +779,13 @@ func (f GGML) SupportsFlashAttention() bool {
 	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
 }
 
+// FlashAttention checks if the model should enable flash attention
+func (f GGML) FlashAttention() bool {
+	return slices.Contains([]string{
+		"gptoss", "gpt-oss",
+	}, f.KV().String("general.architecture"))
+}
+
 // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
 func kvCacheBytesPerElement(cacheType string) float64 {
 	switch cacheType {
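The default is keyed purely off the model's architecture string. A sketch of the same allow-list pattern in isolation, with a hypothetical helper standing in for the method and its KV() lookup:

package main

import (
	"fmt"
	"slices"
)

// wantsFlashAttention mirrors the new GGML.FlashAttention method:
// only gpt-oss models opt in by default.
func wantsFlashAttention(arch string) bool {
	return slices.Contains([]string{"gptoss", "gpt-oss"}, arch)
}

func main() {
	fmt.Println(wantsFlashAttention("gptoss")) // true
	fmt.Println(wantsFlashAttention("llama"))  // false
}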
llm/memory.go

@@ -195,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		slog.Warn("model missing blk.0 layer size")
 	}
 
-	var kvct string
-	if envconfig.FlashAttention() &&
+	useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
 		discover.GetGPUInfo().FlashAttentionSupported() &&
-		f.SupportsFlashAttention() {
+		f.SupportsFlashAttention()
+
+	var kvct string
+	if useFlashAttention {
 		requested := strings.ToLower(envconfig.KvCacheType())
 		if requested != "" && f.SupportsKVCacheType(requested) {
 			kvct = requested
 		}
 	}
 
-	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
+	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
 
 	if len(kv) > 0 {
 		layerSize += kv[0]
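Hoisting the condition into useFlashAttention lets a single boolean both gate the KV cache type selection and feed the new GraphSize parameter. A condensed sketch of that condition, with the envconfig/discover/ggml calls replaced by plain booleans:

package main

import "fmt"

// shouldUseFlashAttention condenses the hoisted expression:
// (envconfig.FlashAttention() || f.FlashAttention()) &&
//     discover.GetGPUInfo().FlashAttentionSupported() &&
//     f.SupportsFlashAttention()
func shouldUseFlashAttention(envOn, modelWantsIt, gpuSupports, modelSupports bool) bool {
	return (envOn || modelWantsIt) && gpuSupports && modelSupports
}

func main() {
	// A gpt-oss model on a capable GPU now enables flash attention
	// even when the environment flag is unset:
	fmt.Println(shouldUseFlashAttention(false, true, true, true)) // true
}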
llm/server.go

@@ -195,6 +195,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
 	// that can handle it.
 	fa := envconfig.FlashAttention()
+	if f.FlashAttention() {
+		slog.Info("model wants flash attention")
+		fa = true
+	}
+
 	if fa && !gpus.FlashAttentionSupported() {
 		slog.Warn("flash attention enabled but not supported by gpu")
 		fa = false
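In NewLlamaServer the precedence is: the environment flag seeds the decision, a model preference can switch flash attention on, and missing GPU support always switches it off. A self-contained sketch of that ordering, with a simplified signature in place of the real GPU and model types:

package main

import (
	"fmt"
	"log/slog"
)

// resolveFlashAttention mirrors the precedence shown in the hunk above.
func resolveFlashAttention(envOn, modelWantsIt, gpusSupport bool) bool {
	fa := envOn
	if modelWantsIt {
		slog.Info("model wants flash attention")
		fa = true
	}
	if fa && !gpusSupport {
		slog.Warn("flash attention enabled but not supported by gpu")
		fa = false
	}
	return fa
}

func main() {
	fmt.Println(resolveFlashAttention(false, true, true))  // true: model opt-in wins
	fmt.Println(resolveFlashAttention(true, false, false)) // false: GPU veto wins
}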