ml: update Context.Forward interface

update Context.Forward to accept multiple tensors to match
Context.Compute signature

update Context.Forward to return Context such that it can be chained
with Context.Compute
This commit is contained in:
Michael Yang 2025-02-21 11:57:08 -08:00
parent 41dc280491
commit 3e8b8a1933
6 changed files with 18 additions and 14 deletions

View File

@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
)
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {

View File

@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
out, _, mask := cache.Get(context)
context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)

View File

@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer]),
value.Copy(ctx, c.values[c.curLayer]),
)
}
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {

View File

@ -65,7 +65,7 @@ type Context interface {
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
Forward(Tensor)
Forward(...Tensor) Context
Compute(...Tensor)
MaxTensors() int
Close()
@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
if t.Bytes() == nil {
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
}
s := make(S, mul(t.Shape()...))

View File

@ -256,12 +256,16 @@ type Context struct {
nodes int
}
func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
if c.graph == nil {
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
}
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
for _, tensor := range tensors {
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
}
return c
}
func (c *Context) Compute(tensors ...ml.Tensor) {

View File

@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
return nil, err
}
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
return t, nil
}