mirror of
https://github.com/ollama/ollama.git
synced 2025-03-18 05:41:43 +01:00
ml: update Context.Forward interface
update Context.Forward to accept multiple tensors to match Context.Compute signature update Context.Forward to return Context such that it can be chained with Context.Compute
This commit is contained in:
parent
41dc280491
commit
3e8b8a1933
@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
|
||||
ctx.Forward(
|
||||
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
|
||||
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
|
@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
|
||||
out, _, mask := cache.Get(context)
|
||||
|
||||
context.Forward(out)
|
||||
context.Forward(mask)
|
||||
context.Compute(out, mask)
|
||||
context.Forward(out, mask).Compute(out, mask)
|
||||
|
||||
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
|
||||
|
@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
|
||||
ctx.Forward(
|
||||
key.Copy(ctx, c.keys[c.curLayer]),
|
||||
value.Copy(ctx, c.values[c.curLayer]),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
|
@ -65,7 +65,7 @@ type Context interface {
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
|
||||
Forward(Tensor)
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
MaxTensors() int
|
||||
Close()
|
||||
@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
|
||||
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
if t.Bytes() == nil {
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
ctx.Forward(t).Compute(t)
|
||||
}
|
||||
|
||||
s := make(S, mul(t.Shape()...))
|
||||
|
@ -256,12 +256,16 @@ type Context struct {
|
||||
nodes int
|
||||
}
|
||||
|
||||
func (c *Context) Forward(t ml.Tensor) {
|
||||
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
||||
if c.graph == nil {
|
||||
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
|
||||
}
|
||||
|
||||
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
|
||||
for _, tensor := range tensors {
|
||||
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
|
@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
ctx.Forward(t).Compute(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user