package convert

import (
	"bytes"
	"encoding/binary"
	"os"
	"path/filepath"
	"testing"

	"github.com/d4l3k/go-bfloat16"
	"github.com/google/go-cmp/cmp"
	"github.com/x448/float16"
)

// TestSafetensors writes a small raw tensor payload (the values 0..31 encoded
// as F32, F16, BF16 or U8) into a file inside a temporary os.Root, points a
// safetensor at that file and checks the exact bytes produced by
// safetensor.WriteTo against golden data.
//
// The golden data encodes the following expectations: with shape {32} the
// F32, F16 and BF16 inputs all come out as little-endian float32, while with
// shape {16, 2} F32 and F16 inputs come out as float16 and BF16 input is
// passed through unchanged. U8 input is passed through unchanged.
// NOTE(review): presumably WriteTo picks the output type from the tensor's
// rank/name - confirm against the safetensor implementation.
func TestSafetensors(t *testing.T) {
	t.Parallel()

	root, err := os.OpenRoot(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}
	defer root.Close()

	// Every fixture holds the same 32 values: 0, 1, ..., 31.
	const numElems = 32

	// writeF32 stores the fixture values as little-endian float32.
	writeF32 := func(t *testing.T, f *os.File) {
		t.Helper()
		f32s := make([]float32, numElems)
		for i := range f32s {
			f32s[i] = float32(i)
		}
		if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
			t.Fatal(err)
		}
	}

	// writeF16 stores the fixture values as little-endian IEEE half floats.
	writeF16 := func(t *testing.T, f *os.File) {
		t.Helper()
		u16s := make([]uint16, numElems)
		for i := range u16s {
			u16s[i] = float16.Fromfloat32(float32(i)).Bits()
		}
		if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
			t.Fatal(err)
		}
	}

	// writeBF16 stores the fixture values as bfloat16, using the byte
	// encoding produced by bfloat16.EncodeFloat32.
	writeBF16 := func(t *testing.T, f *os.File) {
		t.Helper()
		f32s := make([]float32, numElems)
		for i := range f32s {
			f32s[i] = float32(i)
		}
		if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
			t.Fatal(err)
		}
	}

	// writeU8 stores the fixture values as single bytes.
	writeU8 := func(t *testing.T, f *os.File) {
		t.Helper()
		u8s := make([]uint8, numElems)
		for i := range u8s {
			u8s[i] = uint8(i)
		}
		if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
			t.Fatal(err)
		}
	}

	// Golden output: 0..31 as little-endian float32.
	wantF32 := []byte{
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
		0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
		0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
		0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
		0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
		0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
		0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
		0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
	}

	// Golden output: 0..31 as little-endian float16.
	wantF16 := []byte{
		0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
		0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
		0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
		0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
	}

	// Golden output: 0..31 as little-endian bfloat16.
	wantBF16 := []byte{
		0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
		0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
		0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
		0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
	}

	// Golden output: 0..31 as raw bytes.
	wantU8 := []byte{
		0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
		0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
	}

	cases := []struct {
		name, dtype  string
		offset, size int64
		shape        []uint64
		setup        func(*testing.T, *os.File)
		want         []byte
	}{
		{
			name:  "fp32-fp32",
			dtype: "F32",
			size:  numElems * 4, // 32 floats, each 4 bytes
			shape: []uint64{32},
			setup: writeF32,
			want:  wantF32,
		},
		{
			name:  "fp32-fp16",
			dtype: "F32",
			size:  numElems * 4, // 32 floats, each 4 bytes
			shape: []uint64{16, 2},
			setup: writeF32,
			want:  wantF16,
		},
		{
			name:  "fp16-fp16",
			dtype: "F16",
			size:  numElems * 2, // 32 half floats, each 2 bytes
			shape: []uint64{16, 2},
			setup: writeF16,
			want:  wantF16,
		},
		{
			name:  "fp16-fp32",
			dtype: "F16",
			size:  numElems * 2, // 32 half floats, each 2 bytes
			shape: []uint64{32},
			setup: writeF16,
			want:  wantF32,
		},
		{
			name:  "bf16-bf16",
			dtype: "BF16",
			size:  numElems * 2, // 32 brain floats, each 2 bytes
			shape: []uint64{16, 2},
			setup: writeBF16,
			want:  wantBF16,
		},
		{
			name:  "bf16-fp32",
			dtype: "BF16",
			size:  numElems * 2, // 32 brain floats, each 2 bytes
			shape: []uint64{32},
			setup: writeBF16,
			want:  wantF32,
		},
		{
			name:  "u8-u8",
			dtype: "U8",
			size:  numElems, // 32 unsigned bytes, each 1 byte
			shape: []uint64{32},
			setup: writeU8,
			want:  wantU8,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			// t.Name() is "TestSafetensors/<case>"; keep only the case name
			// so the fixture is a plain file directly under root.
			path := filepath.Base(t.Name())

			f, err := root.Create(path)
			if err != nil {
				t.Fatal(err)
			}
			// Safety net in case setup aborts the test via t.Fatal. On the
			// normal path the file is already closed below, so the error
			// from this second Close is deliberately ignored.
			defer f.Close()

			tt.setup(t, f)

			// Close before reading so a failed write-back is reported here
			// instead of surfacing as a confusing byte mismatch.
			if err := f.Close(); err != nil {
				t.Fatal(err)
			}

			st := safetensor{
				fs:     root.FS(),
				path:   path,
				dtype:  tt.dtype,
				offset: tt.offset,
				size:   tt.size,
				tensorBase: &tensorBase{
					name:  tt.name,
					shape: tt.shape,
				},
			}

			var b bytes.Buffer
			if _, err := st.WriteTo(&b); err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
				t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}