diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 617f53635..b1dc7d779 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -362,7 +362,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { } func (c *testContext) Input() ml.Context { return c } -func (c *testContext) Output() ml.Context { return c } func (c *testContext) Layer(int) ml.Context { return c } func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } diff --git a/ml/backend.go b/ml/backend.go index 354faf432..cfb18d6a9 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -110,12 +110,10 @@ type Context interface { MaxGraphNodes() int Close() - // Input returns a context appropriate for creating input tensors + // Input returns a context appropriate for creating tensors that are + // inputs to the model (which includes things like output locations) Input() Context - // Output returns a context appropriate for creating output tensors - Output() Context - // Layer returns a context appropriate for creating intermediate tensors Layer(int) Context } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index f6b017748..b6f59ae0e 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -48,9 +48,6 @@ type Backend struct { // input is the backend used for inputs input *C.struct_ggml_backend_buffer_type - // output is the backend used for outputs - output *C.struct_ggml_backend_buffer_type - // layers is the backend used for repeating layers layers map[int]*C.struct_ggml_backend_buffer_type @@ -400,8 +397,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, C.size_t(maxGraphNodes), C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)), ), - input: deviceBufferTypes[input.d], - output: deviceBufferTypes[output.d], + input: deviceBufferTypes[input.d], layers: func() map[int]*C.struct_ggml_backend_buffer_type { m := make(map[int]*C.struct_ggml_backend_buffer_type) for i, layer := range layers { @@ -482,19 +478,6 @@ func (c Context) Input() ml.Context { return &c } -func (c Context) Output() ml.Context { - if c.b.output != nil { - return &Context{ - b: c.b, - ctx: c.ctx, - buft: c.b.output, - maxGraphNodes: c.maxGraphNodes, - } - } - - return &c -} - func (c Context) Layer(i int) ml.Context { if buft, ok := c.b.layers[i]; ok { return &Context{