ml: Add support for quantized KV cache

As with the llama engine, quantizing the KV cache requires
flash attention to be enabled through the Ollama server.
This commit is contained in:
Jesse Gross 2025-02-21 20:54:14 -08:00
parent 7b5963513c
commit 5beede47d9
3 changed files with 13 additions and 3 deletions

View File

@ -194,7 +194,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16:
case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
@ -262,5 +262,7 @@ const (
DTypeOther DType = iota
DTypeF32
DTypeF16
DTypeQ80
DTypeQ40
DTypeI32
)

View File

@ -328,6 +328,10 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeF16:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeQ80:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_Q8_0, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeQ40:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_Q4_0, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
@ -437,6 +441,10 @@ func (t *Tensor) DType() ml.DType {
return ml.DTypeF32
case C.GGML_TYPE_F16:
return ml.DTypeF16
case C.GGML_TYPE_Q8_0:
return ml.DTypeQ80
case C.GGML_TYPE_Q4_0:
return ml.DTypeQ40
case C.GGML_TYPE_I32:
return ml.DTypeI32
default:

View File

@ -62,9 +62,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType {
switch s {
case "q8_0":
panic("kv cache quantization not yet implemented")
return ml.DTypeQ80
case "q4_0":
panic("kv cache quantization not yet implemented")
return ml.DTypeQ40
default:
return ml.DTypeF16
}