diff --git a/kvcache/causal.go b/kvcache/causal.go index 5d46f8d4c..69068439e 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity)) } - ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2)))) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2)))) + ctx.Forward( + key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))), + value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))), + ) } func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 874e47433..bbbdf8368 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) out, _, mask := cache.Get(context) - context.Forward(out) - context.Forward(mask) - context.Compute(out, mask) + context.Forward(out, mask).Compute(out, mask) if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) { t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask) diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 8a44c194b..b85b1046a 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) } - ctx.Forward(key.Copy(ctx, c.keys[c.curLayer])) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer])) + ctx.Forward( + key.Copy(ctx, c.keys[c.curLayer]), + value.Copy(ctx, c.values[c.curLayer]), + ) } func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { diff --git a/ml/backend.go b/ml/backend.go index a742ee5c0..07bc75b64 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -65,7 +65,7 @@ type Context interface { FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) (Tensor, error) - Forward(Tensor) + Forward(...Tensor) Context Compute(...Tensor) MaxTensors() int Close() @@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string { if t.Bytes() == nil { - ctx.Forward(t) - ctx.Compute(t) + ctx.Forward(t).Compute(t) } s := make(S, mul(t.Shape()...)) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2d7cf340e..7f91990c3 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -256,12 +256,16 @@ type Context struct { nodes int } -func (c *Context) Forward(t ml.Tensor) { +func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { if c.graph == nil { c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false) } - C.ggml_build_forward_expand(c.graph, t.(*Tensor).t) + for _, tensor := range tensors { + C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t) + } + + return c } func (c *Context) Compute(tensors ...ml.Tensor) { diff --git a/model/model.go b/model/model.go index 0b5996d9f..16020b354 100644 --- a/model/model.go +++ b/model/model.go @@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { return nil, err } - ctx.Forward(t) - ctx.Compute(t) + ctx.Forward(t).Compute(t) return t, nil }