diff --git a/ml/backend.go b/ml/backend.go index d23537ce3..41679f3b3 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -49,7 +49,7 @@ type Context interface { FromIntSlice(s []int32, shape ...int) (Tensor, error) Forward(Tensor) - Compute(Tensor) Tensor + Compute(...Tensor) Close() } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index b0b4be14d..d039a3ea3 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -23,7 +23,7 @@ import ( "github.com/ollama/ollama/ml" "golang.org/x/sync/errgroup" - "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) type device struct { @@ -243,15 +243,17 @@ func (c *Context) Forward(t ml.Tensor) { C.ggml_build_forward_expand(c.graph, t.(*Tensor).t) } -func (c *Context) Compute(t ml.Tensor) ml.Tensor { - c.Forward(t) +func (c *Context) Compute(tensors ...ml.Tensor) { C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) - backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) + for _, t := range tensors { + if C.ggml_nbytes(t.(*Tensor).t) != 0 { + backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) - t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) - C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) - return t + t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) + C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + } + } } func shapeToGGML(shape []int) *C.int64_t { @@ -292,6 +294,13 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { n := len(s) + + if n == 0 { + var shape C.int64_t = 0 + t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape) + return &Tensor{t: t}, nil + } + for _, v := range shape { n /= v } @@ -351,11 +360,7 @@ func (t *Tensor) Shape() []int { } func (t *Tensor) Bytes() []byte { - if bts := C.ggml_get_data(t.t); bts != nil { - return C.GoBytes(bts, C.int(C.ggml_nbytes(t.t))) - } - - return nil + return t.data } func (t *Tensor) Floats() (f32s []float32) { diff --git a/model/model.go b/model/model.go index 8f177f0eb..9290b6d30 100644 --- a/model/model.go +++ b/model/model.go @@ -275,5 +275,8 @@ func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) { } defer ctx.Close() - return ctx.Compute(t), nil + ctx.Forward(t) + ctx.Compute(t) + + return t, nil }