mirror of
https://github.com/ollama/ollama.git
synced 2025-03-26 17:51:48 +01:00
ggml-backend: Ensure data is available after async computation
We need to sync before retrieving data after async computation. It is also important to ensure that the Go buffer is not moved by the GC across function calls, so we do a synchronous copy.
This commit is contained in:
parent
01d9a46854
commit
60830695c2
@ -9,8 +9,6 @@ package ggml
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
|
||||
|
||||
for _, t := range tensors {
|
||||
if C.ggml_nbytes(t.(*Tensor).t) != 0 {
|
||||
backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
|
||||
needSync := true
|
||||
sync := func() {
|
||||
if needSync {
|
||||
C.ggml_backend_sched_synchronize(c.sched)
|
||||
needSync = false
|
||||
}
|
||||
}
|
||||
|
||||
t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
|
||||
C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
for _, t := range tensors {
|
||||
if C.ggml_nbytes(t.(*Tensor).t) > 0 {
|
||||
t.(*Tensor).sync = sync
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -330,7 +333,7 @@ func (c *Context) Close() {
|
||||
|
||||
type Tensor struct {
|
||||
t *C.struct_ggml_tensor
|
||||
data []byte
|
||||
sync func()
|
||||
}
|
||||
|
||||
func (t *Tensor) LogValue() slog.Value {
|
||||
@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
|
||||
return shape
|
||||
}
|
||||
|
||||
func (t *Tensor) Bytes() []byte {
|
||||
return t.data
|
||||
func (t *Tensor) Bytes() (data []byte) {
|
||||
if t.sync != nil {
|
||||
data = make([]byte, C.ggml_nbytes(t.t))
|
||||
|
||||
t.sync()
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) Floats() (f32s []float32) {
|
||||
if t.data != nil {
|
||||
f32s = make([]float32, C.ggml_nelements(t.t))
|
||||
_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
|
||||
func (t *Tensor) Floats() (data []float32) {
|
||||
if t.sync != nil {
|
||||
data = make([]float32, C.ggml_nelements(t.t))
|
||||
|
||||
t.sync()
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user