mirror of
https://github.com/ollama/ollama.git
synced 2025-08-27 19:38:41 +02:00
convert(gptoss): mxfp4 to ggml layout to avoid jit conversion (#12018)
* convert: return bytes written * ggml flavor mxfp4 * simplify jit conversion * comment
This commit is contained in:
@@ -172,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
blocksDims[i] = int(d)
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
|
||||
bts := b.Bytes()
|
||||
var tmp [16]byte
|
||||
for i := 0; i < b.Len(); i += 16 {
|
||||
for j := range 8 {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[i+j], bts[i+j+8]
|
||||
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
|
||||
copy(bts[i:i+16], tmp[:])
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||
|
||||
var s bytes.Buffer
|
||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||
@@ -206,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
return int64(len(u8s)), nil
|
||||
}
|
||||
|
@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
|
||||
const (
|
||||
tensorKindFP32 uint32 = iota
|
||||
tensorKindFP16
|
||||
tensorKindMXFP4 = 4
|
||||
tensorKindBF16 = 30
|
||||
tensorKindMXFP4 = 39
|
||||
)
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
|
@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
switch st.Kind() {
|
||||
case tensorKindFP32:
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, u8s)
|
||||
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
@@ -290,24 +290,24 @@ func (t Tensor) blockSize() uint64 {
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
0, // F32
|
||||
1, // F16
|
||||
24, // I8
|
||||
25, // I16
|
||||
26, // I32
|
||||
27, // I64
|
||||
28, // F64
|
||||
30: // BF16
|
||||
TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case
|
||||
2, // Q4_0
|
||||
3, // Q4_1
|
||||
4, // MXFP4
|
||||
6, // Q5_0
|
||||
7, // Q5_1
|
||||
8, // Q8_0
|
||||
9, // Q8_1
|
||||
20: // IQ4_NL
|
||||
TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
@@ -330,8 +330,6 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return 2 + blockSize/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + blockSize/2
|
||||
case TensorTypeMXFP4, 39:
|
||||
return 1 + blockSize/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + blockSize/2
|
||||
case TensorTypeQ5_1:
|
||||
@@ -382,6 +380,8 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + blockSize/2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
@@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ4_0
|
||||
case fileTypeQ4_1:
|
||||
return TensorTypeQ4_1
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
case FileTypeQ8_0:
|
||||
return TensorTypeQ8_0
|
||||
case fileTypeQ5_0:
|
||||
@@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
@@ -191,8 +191,8 @@ const (
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
@@ -226,6 +226,7 @@ const (
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
@@ -318,7 +319,7 @@ func (t TensorType) String() string {
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
case TensorTypeMXFP4:
|
||||
case 4, TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
default:
|
||||
return "unknown"
|
||||
|
@@ -535,6 +535,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
const BS = 17 // MXFP4 block size
|
||||
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
|
||||
var s uint64
|
||||
var tmp [16]byte
|
||||
for s < t.Size() {
|
||||
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
||||
if err := ctx.Err(); err != nil {
|
||||
@@ -546,37 +547,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
return err
|
||||
}
|
||||
for j := range n / BS {
|
||||
for i := 1; i < BS; i++ {
|
||||
// swap nibbles
|
||||
t_lo := bts[j*BS+i] & 0x0F
|
||||
t_hi := bts[j*BS+i] & 0xF0
|
||||
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
|
||||
}
|
||||
// transform aaaa...bbbb... to abababab...
|
||||
oi := 0
|
||||
tmp := [16]byte{}
|
||||
for i := 1; i < 9; i++ {
|
||||
blk_a0 := bts[j*BS+i] & 0xF0
|
||||
blk_a1 := bts[j*BS+i] << 4
|
||||
blk_b0 := bts[j*BS+i+8] >> 4
|
||||
blk_b1 := bts[j*BS+i+8] & 0x0F
|
||||
// swap once more
|
||||
out0 := blk_a0 | blk_b0
|
||||
out1 := blk_a1 | blk_b1
|
||||
out_h0 := out0 & 0xF0
|
||||
out_l0 := out0 & 0x0F
|
||||
out_h1 := out1 & 0xF0
|
||||
out_l1 := out1 & 0x0F
|
||||
out0 = (out_h0 >> 4) | (out_l0 << 4)
|
||||
out1 = (out_h1 >> 4) | (out_l1 << 4)
|
||||
tmp[oi] = out0
|
||||
oi++
|
||||
tmp[oi] = out1
|
||||
oi++
|
||||
}
|
||||
for i := range tmp {
|
||||
bts[j*BS+i+1] = tmp[i]
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[j*BS+i], bts[j*BS+i+8]
|
||||
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
copy(bts[j*BS+1:j*BS+17], tmp[:])
|
||||
}
|
||||
|
||||
for _, tt := range tts {
|
||||
|
Reference in New Issue
Block a user