ggml-backend: Ensure data is available after async computation

We need to sync before retrieving data after async computation.
It is also important to ensure that the Go buffer is not moved by
the GC across function calls so we do a synchronous copy.
This commit is contained in:
Jesse Gross 2025-02-05 13:18:36 -08:00 committed by Jesse Gross
parent 01d9a46854
commit 60830695c2

View File

@ -9,8 +9,6 @@ package ggml
import "C"
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log/slog"
@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
for _, t := range tensors {
if C.ggml_nbytes(t.(*Tensor).t) != 0 {
backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
needSync := true
sync := func() {
if needSync {
C.ggml_backend_sched_synchronize(c.sched)
needSync = false
}
}
t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
for _, t := range tensors {
if C.ggml_nbytes(t.(*Tensor).t) > 0 {
t.(*Tensor).sync = sync
}
}
}
@ -330,7 +333,7 @@ func (c *Context) Close() {
type Tensor struct {
t *C.struct_ggml_tensor
data []byte
sync func()
}
func (t *Tensor) LogValue() slog.Value {
@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
return shape
}
func (t *Tensor) Bytes() []byte {
return t.data
func (t *Tensor) Bytes() (data []byte) {
if t.sync != nil {
data = make([]byte, C.ggml_nbytes(t.t))
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
}
return
}
func (t *Tensor) Floats() (f32s []float32) {
if t.data != nil {
f32s = make([]float32, C.ggml_nelements(t.t))
_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
func (t *Tensor) Floats() (data []float32) {
if t.sync != nil {
data = make([]float32, C.ggml_nelements(t.t))
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
}
return