From ef7d26ba2cd52a5295620067ab9abd9c0055a558 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 14 Aug 2025 15:03:57 -0700 Subject: [PATCH] convert: skip reading into memory when possible (#11507) if there's no transformation to the tensor and the input and output types match, copy directly into the writer. also read from a bufio with a 32K buffer --- convert/reader_safetensors.go | 39 ++++-- convert/reader_test.go | 232 ++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+), 13 deletions(-) create mode 100644 convert/reader_test.go diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index 63f31631dd..ccc5967328 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -1,6 +1,7 @@ package convert import ( + "bufio" "bytes" "encoding/binary" "encoding/json" @@ -124,26 +125,41 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { } defer f.Close() - if seeker, ok := f.(io.Seeker); ok { - if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil { - return 0, err - } - } else { - if _, err := io.CopyN(io.Discard, f, st.offset); err != nil { - return 0, err + r, err := func() (io.Reader, error) { + if readerAt, ok := f.(io.ReaderAt); ok { + return io.NewSectionReader(readerAt, st.offset, st.size), nil + } else if seeker, ok := f.(io.Seeker); ok { + _, err := seeker.Seek(st.offset, io.SeekStart) + return f, err + } else { + _, err := io.CopyN(io.Discard, f, st.offset) + return f, err } + }() + if err != nil { + return 0, err + } + + br := bufio.NewReaderSize(r, min(32<<10, int(st.size))) + // special case when input and output are same type and the + // tensor doesn't need repacking + if (st.repacker == nil) && + ((st.dtype == "F32" && st.Kind() == tensorKindFP32) || + (st.dtype == "F16" && st.Kind() == tensorKindFP16) || + (st.dtype == "U8")) { + return io.CopyN(w, br, st.size) } var f32s []float32 switch st.dtype { case "F32": f32s = make([]float32, st.size/4) - if err = binary.Read(f, binary.LittleEndian, f32s); err != nil { + if err = binary.Read(br, binary.LittleEndian, f32s); err != nil { return 0, err } case "F16": u16s := make([]uint16, st.size/2) - if err = binary.Read(f, binary.LittleEndian, u16s); err != nil { + if err = binary.Read(br, binary.LittleEndian, u16s); err != nil { return 0, err } @@ -154,14 +170,11 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { case "BF16": u8s := make([]uint8, st.size) - if err = binary.Read(f, binary.LittleEndian, u8s); err != nil { + if err = binary.Read(br, binary.LittleEndian, u8s); err != nil { return 0, err } f32s = bfloat16.DecodeFloat32(u8s) - case "U8": - // U8 tensors do not support repacking or type conversion. - return io.CopyN(w, f, st.size) default: return 0, fmt.Errorf("unknown data type: %s", st.dtype) } diff --git a/convert/reader_test.go b/convert/reader_test.go new file mode 100644 index 0000000000..6dbe32a51d --- /dev/null +++ b/convert/reader_test.go @@ -0,0 +1,232 @@ +package convert + +import ( + "bytes" + "encoding/binary" + "os" + "path/filepath" + "testing" + + "github.com/d4l3k/go-bfloat16" + "github.com/google/go-cmp/cmp" + "github.com/x448/float16" +) + +func TestSafetensors(t *testing.T) { + t.Parallel() + + root, err := os.OpenRoot(t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer root.Close() + + cases := []struct { + name, + dtype string + offset, + size int64 + shape []uint64 + setup func(*testing.T, *os.File) + want []byte + }{ + { + name: "fp32-fp32", + dtype: "F32", + size: 32 * 4, // 32 floats, each 4 bytes + shape: []uint64{32}, + setup: func(t *testing.T, f *os.File) { + f32s := make([]float32, 32) + for i := range f32s { + f32s[i] = float32(i) + } + + if err := binary.Write(f, binary.LittleEndian, f32s); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, + 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, + 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, + 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, + 0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41, + 0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41, + 0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41, + 0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41, + }, + }, + { + name: "fp32-fp16", + dtype: "F32", + size: 32 * 4, // 32 floats, each 4 bytes + shape: []uint64{16, 2}, + setup: func(t *testing.T, f *os.File) { + f32s := make([]float32, 32) + for i := range f32s { + f32s[i] = float32(i) + } + + if err := binary.Write(f, binary.LittleEndian, f32s); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47, + 0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b, + 0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d, + 0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f, + }, + }, + { + name: "fp16-fp16", + dtype: "F16", + size: 32 * 2, // 32 floats, each 2 bytes + shape: []uint64{16, 2}, + setup: func(t *testing.T, f *os.File) { + u16s := make([]uint16, 32) + for i := range u16s { + u16s[i] = float16.Fromfloat32(float32(i)).Bits() + } + + if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47, + 0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b, + 0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d, + 0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f, + }, + }, + { + name: "fp16-fp32", + dtype: "F16", + size: 32 * 2, // 32 floats, each 2 bytes + shape: []uint64{32}, + setup: func(t *testing.T, f *os.File) { + u16s := make([]uint16, 32) + for i := range u16s { + u16s[i] = float16.Fromfloat32(float32(i)).Bits() + } + + if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, + 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, + 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, + 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, + 0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41, + 0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41, + 0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41, + 0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41, + }, + }, + { + name: "bf16-bf16", + dtype: "BF16", + size: 32 * 2, // 32 brain floats, each 2 bytes + shape: []uint64{16, 2}, + setup: func(t *testing.T, f *os.File) { + f32s := make([]float32, 32) + for i := range f32s { + f32s[i] = float32(i) + } + + if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40, + 0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41, + 0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41, + 0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41, + }, + }, + { + name: "bf16-fp32", + dtype: "BF16", + size: 32 * 2, // 32 brain floats, each 2 bytes + shape: []uint64{32}, + setup: func(t *testing.T, f *os.File) { + f32s := make([]float32, 32) + for i := range f32s { + f32s[i] = float32(i) + } + + if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, + 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, + 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, + 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, + 0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41, + 0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41, + 0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41, + 0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41, + }, + }, + { + name: "u8-u8", + dtype: "U8", + size: 32, // 32 brain floats, each 1 bytes + shape: []uint64{32}, + setup: func(t *testing.T, f *os.File) { + u8s := make([]uint8, 32) + for i := range u8s { + u8s[i] = uint8(i) + } + + if err := binary.Write(f, binary.LittleEndian, u8s); err != nil { + t.Fatal(err) + } + }, + want: []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + path := filepath.Base(t.Name()) + st := safetensor{ + fs: root.FS(), + path: path, + dtype: tt.dtype, + offset: tt.offset, + size: tt.size, + tensorBase: &tensorBase{ + name: tt.name, + shape: tt.shape, + }, + } + + f, err := root.Create(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + tt.setup(t, f) + + var b bytes.Buffer + if _, err := st.WriteTo(&b); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" { + t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff) + } + }) + } +}