From 60830695c2c4db46925436c7ecdb4bc00febcdb3 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 5 Feb 2025 13:18:36 -0800 Subject: [PATCH] ggml-backend: Ensure data is available after async computation We need to sync before retrieving data after async computation. It is also important to ensure that the Go buffer is not moved by the GC across function calls so we do a synchronous copy. --- ml/backend/ggml/ggml.go | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 8b33d38fd..6eba3c602 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -9,8 +9,6 @@ package ggml import "C" import ( - "bytes" - "encoding/binary" "fmt" "io" "log/slog" @@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) { func (c *Context) Compute(tensors ...ml.Tensor) { C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) - for _, t := range tensors { - if C.ggml_nbytes(t.(*Tensor).t) != 0 { - backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) + needSync := true + sync := func() { + if needSync { + C.ggml_backend_sched_synchronize(c.sched) + needSync = false + } + } - t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) - C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + for _, t := range tensors { + if C.ggml_nbytes(t.(*Tensor).t) > 0 { + t.(*Tensor).sync = sync } } } @@ -330,7 +333,7 @@ func (c *Context) Close() { type Tensor struct { t *C.struct_ggml_tensor - data []byte + sync func() } func (t *Tensor) LogValue() slog.Value { @@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int { return shape } -func (t *Tensor) Bytes() []byte { - return t.data +func (t *Tensor) Bytes() (data []byte) { + if t.sync != nil { + data = make([]byte, C.ggml_nbytes(t.t)) + + t.sync() + C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t)) + } + + return } -func (t *Tensor) Floats() (f32s []float32) { - if t.data != nil { - f32s = make([]float32, C.ggml_nelements(t.t)) - _ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s) +func (t *Tensor) Floats() (data []float32) { + if t.sync != nil { + data = make([]float32, C.ggml_nelements(t.t)) + + t.sync() + C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t)) } return