diff --git a/ml/backend.go b/ml/backend.go
index 2f049924a..50534643f 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -194,7 +194,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
 		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
 			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
 		})
-	case DTypeF16:
+	case DTypeF16, DTypeQ80, DTypeQ40:
 		f32 := ctx.Zeros(DTypeF32, t.Shape()...)
 		f32 = t.Copy(ctx, f32)
 		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
@@ -262,5 +262,7 @@ const (
 	DTypeOther DType = iota
 	DTypeF32
 	DTypeF16
+	DTypeQ80
+	DTypeQ40
 	DTypeI32
 )
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index cfc47a58c..838f4acfb 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -328,6 +328,10 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
 	case ml.DTypeF16:
 		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
+	case ml.DTypeQ80:
+		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_Q8_0, C.int(len(shape)), shapeToGGML(shape))
+	case ml.DTypeQ40:
+		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_Q4_0, C.int(len(shape)), shapeToGGML(shape))
 	case ml.DTypeI32:
 		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
 	default:
@@ -437,6 +441,10 @@ func (t *Tensor) DType() ml.DType {
 		return ml.DTypeF32
 	case C.GGML_TYPE_F16:
 		return ml.DTypeF16
+	case C.GGML_TYPE_Q8_0:
+		return ml.DTypeQ80
+	case C.GGML_TYPE_Q4_0:
+		return ml.DTypeQ40
 	case C.GGML_TYPE_I32:
 		return ml.DTypeI32
 	default:
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index e1fa98b1a..1b12173e1 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -62,9 +62,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
 func kvCacheTypeFromStr(s string) ml.DType {
 	switch s {
 	case "q8_0":
-		panic("kv cache quantization not yet implemented")
+		return ml.DTypeQ80
 	case "q4_0":
-		panic("kv cache quantization not yet implemented")
+		return ml.DTypeQ40
 	default:
 		return ml.DTypeF16
 	}
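
A minimal usage sketch (not part of the patch) of how the pieces above fit together, written as if it lived in runner/ollamarunner where kvCacheTypeFromStr is defined; the helper name, the shape arguments, and the assumption that the caller already holds an ml.Context from a loaded backend are all illustrative:

	package ollamarunner

	import "github.com/ollama/ollama/ml"

	// newCacheTensor maps the user-facing cache-type string to an ml.DType and
	// allocates a zeroed tensor of that type. With this change, Zeros accepts
	// DTypeQ80/DTypeQ40 and backs them with GGML_TYPE_Q8_0/GGML_TYPE_Q4_0 in
	// the ggml backend instead of panicking in kvCacheTypeFromStr.
	func newCacheTensor(ctx ml.Context, kvCacheType string, shape ...int) ml.Tensor {
		// "q8_0" -> DTypeQ80, "q4_0" -> DTypeQ40, anything else -> DTypeF16
		dtype := kvCacheTypeFromStr(kvCacheType)
		return ctx.Zeros(dtype, shape...)
	}

Dump already handles the new dtypes by copying the quantized tensor into a DTypeF32 tensor before formatting, so debugging output works the same for f16 and quantized caches.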