From 59412fbb436f85f62b41231c1df91d1ebe286431 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 26 Aug 2025 16:41:02 -0700 Subject: [PATCH] convert(gptoss): mxfp4 to ggml layout to avoid jit conversion (#12018) * convert: return bytes written * ggml flavor mxfp4 * simplify jit conversion * comment --- convert/convert_gptoss.go | 17 +++++++++++++++-- convert/reader.go | 2 +- convert/reader_safetensors.go | 6 +++--- fs/ggml/ggml.go | 36 +++++++++++++++++------------------ fs/ggml/type.go | 11 ++++++----- ml/backend/ggml/ggml.go | 35 ++++++---------------------------- 6 files changed, 49 insertions(+), 58 deletions(-) diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go index c5a691d3d5..2048b18bed 100644 --- a/convert/convert_gptoss.go +++ b/convert/convert_gptoss.go @@ -172,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { blocksDims[i] = int(d) } - var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes())) + bts := b.Bytes() + var tmp [16]byte + for i := 0; i < b.Len(); i += 16 { + for j := range 8 { + // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc + a, b := bts[i+j], bts[i+j+8] + tmp[2*j+0] = (a & 0x0F) | (b << 4) + tmp[2*j+1] = (a >> 4) | (b & 0xF0) + } + + copy(bts[i:i+16], tmp[:]) + } + + var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts)) var s bytes.Buffer if _, err := m.scales.WriteTo(&s); err != nil { @@ -206,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { return 0, err } - return 0, nil + return int64(len(u8s)), nil } diff --git a/convert/reader.go b/convert/reader.go index 907d2a9ef8..b3f7a86603 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 { const ( tensorKindFP32 uint32 = iota tensorKindFP16 - tensorKindMXFP4 = 4 tensorKindBF16 = 30 + tensorKindMXFP4 = 39 ) func (t tensorBase) Kind() uint32 { diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index ccc5967328..7f029f933d 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { switch st.Kind() { case tensorKindFP32: - return 0, binary.Write(w, binary.LittleEndian, f32s) + return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s) case tensorKindFP16: f16s := make([]uint16, len(f32s)) for i := range f32s { f16s[i] = float16.Fromfloat32(f32s[i]).Bits() } - return 0, binary.Write(w, binary.LittleEndian, f16s) + return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s) case tensorKindBF16: u8s := bfloat16.EncodeFloat32(f32s) - return 0, binary.Write(w, binary.LittleEndian, u8s) + return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s) default: return 0, fmt.Errorf("unknown storage type: %d", st.Kind()) } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index feba55eddb..3f4374cd00 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -290,24 +290,24 @@ func (t Tensor) blockSize() uint64 { func (t TensorType) BlockSize() uint64 { switch t { case - 0, // F32 - 1, // F16 - 24, // I8 - 25, // I16 - 26, // I32 - 27, // I64 - 28, // F64 - 30: // BF16 + TensorTypeF32, + TensorTypeF16, + TensorTypeI8, + TensorTypeI16, + TensorTypeI32, + TensorTypeI64, + TensorTypeF64, + TensorTypeBF16: return 1 case - 2, // Q4_0 - 3, // Q4_1 - 4, // MXFP4 - 6, // Q5_0 - 7, // Q5_1 - 8, // Q8_0 - 9, // Q8_1 - 20: // IQ4_NL + TensorTypeQ4_0, + TensorTypeQ4_1, + TensorTypeQ5_0, + TensorTypeQ5_1, + TensorTypeQ8_0, + TensorTypeQ8_1, + tensorTypeIQ4_NL, + 4, TensorTypeMXFP4: return 32 default: return 256 @@ -330,8 +330,6 @@ func (t TensorType) TypeSize() uint64 { return 2 + blockSize/2 case TensorTypeQ4_1: return 2 + 2 + blockSize/2 - case TensorTypeMXFP4, 39: - return 1 + blockSize/2 case TensorTypeQ5_0: return 2 + 4 + blockSize/2 case TensorTypeQ5_1: @@ -382,6 +380,8 @@ func (t TensorType) TypeSize() uint64 { return blockSize/8 + blockSize/16 + blockSize/32 case TensorTypeBF16: return 2 + case 4, TensorTypeMXFP4: + return 1 + blockSize/2 default: return 0 } diff --git a/fs/ggml/type.go b/fs/ggml/type.go index 3e5deb87bd..1a31a5fd8d 100644 --- a/fs/ggml/type.go +++ b/fs/ggml/type.go @@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType { return TensorTypeQ4_0 case fileTypeQ4_1: return TensorTypeQ4_1 - case fileTypeMXFP4: - return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2 case FileTypeQ8_0: return TensorTypeQ8_0 case fileTypeQ5_0: @@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType { return TensorTypeQ2_K case FileTypeBF16: return TensorTypeBF16 + case fileTypeMXFP4: + return TensorTypeMXFP4 default: slog.Warn("unsupported file type", "type", ftype) return 0 // F32 @@ -191,8 +191,8 @@ const ( TensorTypeF16 TensorTypeQ4_0 TensorTypeQ4_1 - TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2 - tensorTypeQ4_3 // unused by GGML + tensorTypeQ4_2 + tensorTypeQ4_3 // unused by GGML TensorTypeQ5_0 TensorTypeQ5_1 TensorTypeQ8_0 @@ -226,6 +226,7 @@ const ( tensorTypeIQ4_NL_4_4 // unused by GGML tensorTypeIQ4_NL_4_8 // unused by GGML tensorTypeIQ4_NL_8_8 // unused by GGML + TensorTypeMXFP4 ) // ParseFileType parses the provided GGUF file type @@ -318,7 +319,7 @@ func (t TensorType) String() string { return "F64" case TensorTypeBF16: return "BF16" - case TensorTypeMXFP4: + case 4, TensorTypeMXFP4: return "MXFP4" default: return "unknown" diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 13d898aadd..e8403e06c2 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -535,6 +535,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { const BS = 17 // MXFP4 block size bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned var s uint64 + var tmp [16]byte for s < t.Size() { // Stop if either the parent context has been canceled or if any of the other tensors returned an error if err := ctx.Err(); err != nil { @@ -546,37 +547,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { return err } for j := range n / BS { - for i := 1; i < BS; i++ { - // swap nibbles - t_lo := bts[j*BS+i] & 0x0F - t_hi := bts[j*BS+i] & 0xF0 - bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4) - } - // transform aaaa...bbbb... to abababab... - oi := 0 - tmp := [16]byte{} for i := 1; i < 9; i++ { - blk_a0 := bts[j*BS+i] & 0xF0 - blk_a1 := bts[j*BS+i] << 4 - blk_b0 := bts[j*BS+i+8] >> 4 - blk_b1 := bts[j*BS+i+8] & 0x0F - // swap once more - out0 := blk_a0 | blk_b0 - out1 := blk_a1 | blk_b1 - out_h0 := out0 & 0xF0 - out_l0 := out0 & 0x0F - out_h1 := out1 & 0xF0 - out_l1 := out1 & 0x0F - out0 = (out_h0 >> 4) | (out_l0 << 4) - out1 = (out_h1 >> 4) | (out_l1 << 4) - tmp[oi] = out0 - oi++ - tmp[oi] = out1 - oi++ - } - for i := range tmp { - bts[j*BS+i+1] = tmp[i] + // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc + a, b := bts[j*BS+i], bts[j*BS+i+8] + tmp[2*(i-1)] = (a & 0x0F) | (b << 4) + tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0) } + copy(bts[j*BS+1:j*BS+17], tmp[:]) } for _, tt := range tts {