feat: add support for flash_attn (#4120)

* feat: enable flash attention if supported

* feat: add flash_attn support
Author:    Sam
Date:      2024-05-21 06:36:03 +10:00 (committed by GitHub)
Parent:    ccdf0b2a44
Commit:    e15307fdf4
2 changed files with 28 additions and 3 deletions

@@ -200,6 +200,23 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}
flashAttnSupported := true
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount() + 1 {
flashAttnSupported = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnSupported = false
}
}
if flashAttnSupported {
params = append(params, "--flash-attn")
}
numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet
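
For reference, the gating logic added above can be read as a standalone predicate. Below is a minimal, self-contained Go sketch of the same two checks; GpuInfo, flashAttnSupported, and the numGPU/blockCount parameters are simplified stand-ins for illustration, not the actual gpu.GpuInfoList or GGML accessors used in llm/server.go.

package main

import "fmt"

// GpuInfo is a simplified stand-in for ollama's gpu.GpuInfo (illustrative only).
type GpuInfo struct {
    Library     string // e.g. "cuda", "metal", "rocm", "cpu"
    DriverMajor int    // used by the commit as a proxy for CUDA compute capability
}

// flashAttnSupported mirrors the two checks added in this commit: the model must
// be fully offloaded (numGPU covers every block plus the output layer), and every
// GPU must be Metal or CUDA with a major version of at least 7.
func flashAttnSupported(gpus []GpuInfo, numGPU int, blockCount uint64) bool {
    // partial offloading does not support flash attention
    if uint64(numGPU) < blockCount+1 {
        return false
    }
    // only cuda (compute capability 7+) and metal support flash attention
    for _, g := range gpus {
        if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
            return false
        }
    }
    return true
}

func main() {
    gpus := []GpuInfo{{Library: "cuda", DriverMajor: 8}}
    fmt.Println(flashAttnSupported(gpus, 33, 32)) // true: 32 blocks + output layer fully offloaded
    fmt.Println(flashAttnSupported(gpus, 16, 32)) // false: partial offload disables flash attention
}

The checks are deliberately conservative: a single unsupported GPU, or any partial offload, keeps --flash-attn out of the server parameters.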