diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 8b33d38fd..6eba3c602 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -9,8 +9,6 @@ package ggml import "C" import ( - "bytes" - "encoding/binary" "fmt" "io" "log/slog" @@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) { func (c *Context) Compute(tensors ...ml.Tensor) { C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) - for _, t := range tensors { - if C.ggml_nbytes(t.(*Tensor).t) != 0 { - backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) + needSync := true + sync := func() { + if needSync { + C.ggml_backend_sched_synchronize(c.sched) + needSync = false + } + } - t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) - C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + for _, t := range tensors { + if C.ggml_nbytes(t.(*Tensor).t) > 0 { + t.(*Tensor).sync = sync } } } @@ -330,7 +333,7 @@ func (c *Context) Close() { type Tensor struct { t *C.struct_ggml_tensor - data []byte + sync func() } func (t *Tensor) LogValue() slog.Value { @@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int { return shape } -func (t *Tensor) Bytes() []byte { - return t.data +func (t *Tensor) Bytes() (data []byte) { + if t.sync != nil { + data = make([]byte, C.ggml_nbytes(t.t)) + + t.sync() + C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t)) + } + + return } -func (t *Tensor) Floats() (f32s []float32) { - if t.data != nil { - f32s = make([]float32, C.ggml_nelements(t.t)) - _ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s) +func (t *Tensor) Floats() (data []float32) { + if t.sync != nil { + data = make([]float32, C.ggml_nelements(t.t)) + + t.sync() + C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t)) } return