ml/backend/ggml: create tensor on specific backend

Some tensors should be created on specific backends to reduce the number of
copies and improve performance.
This commit is contained in:
Michael Yang 2025-02-25 16:06:32 -08:00
parent 764e199d67
commit 7bae7fa5ce
6 changed files with 129 additions and 60 deletions

View File

@ -237,13 +237,13 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
@ -440,7 +440,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContext()
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {

View File

@ -106,7 +106,7 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContext()
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {

View File

@ -24,6 +24,7 @@ type Backend interface {
Config() Config
Get(name string) Tensor
NewContext() Context
NewContextSize(size int) Context
}
// BackendCacheConfig should be implemented by backends that need special output
@ -101,6 +102,15 @@ type Context interface {
Compute(...Tensor)
MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating input tensors
Input() Context
// Output returns a context appropriate for creating output tensors
Output() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
}
type Tensor interface {

View File

@ -41,16 +41,14 @@ func devices() iter.Seq[*C.struct_ggml_backend_device] {
}
type Backend struct {
meta *fs.GGML
meta *fs.GGML
sched *C.struct_ggml_backend_sched
tensors map[string]*C.struct_ggml_tensor
input *C.struct_ggml_backend
output *C.struct_ggml_backend
layers map[int]*C.struct_ggml_backend
flashAttention bool
sched *C.struct_ggml_backend_sched
tensors map[string]*C.struct_ggml_tensor
ctxs []*C.struct_ggml_context
backends []*C.struct_ggml_backend
bufts []*C.struct_ggml_backend_buffer_type
}
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
@ -118,7 +116,6 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
input := dbt{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
slog.Info("input layer", "device", C.GoString(C.ggml_backend_dev_name(input.d)))
var blocks int
for key, value := range meta.KV() {
@ -136,18 +133,14 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
layers := make([]dbt, blocks)
for i := range layers {
layers[i] = gpuBufferTypes[slices.IndexFunc(splits, indexFunc(i))]
slog.Info("layer", "i", i, "device", C.GoString(C.ggml_backend_dev_name(layers[i].d)))
}
output := gpuBufferTypes[slices.IndexFunc(splits, indexFunc(blocks))]
slog.Info("output layer", "device", C.GoString(C.ggml_backend_dev_name(output.d)))
maxTensors := len(meta.Tensors().Items())
maxTensors += 1
maxTensors += blocks * 2
slog.Info("max tensors", "max_tensors", maxTensors)
type tensor struct {
source *fs.Tensor
target string
@ -242,7 +235,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
for bs := range maps.Values(bbs) {
for _, b := range bs {
slog.Info("model", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
}
}
@ -290,11 +283,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return nil, err
}
deviceBackends := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend)
var backends []*C.struct_ggml_backend
var bufts []*C.struct_ggml_backend_buffer_type
for _, d := range append(gpus, append(accels, cpus...)...) {
b := C.ggml_backend_dev_init(d, nil)
backends = append(backends, b)
deviceBackends[d] = b
bt := C.ggml_backend_get_default_buffer_type(b)
if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
@ -305,13 +300,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
bufts = append(bufts, bt)
slog.Info("compute buffer", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
}
return &Backend{
flashAttention: params.FlashAttention,
meta: meta,
tensors: tensors,
meta: meta,
tensors: tensors,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@ -319,6 +314,15 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
C.size_t(max(8192, len(meta.Tensors().Items())*5)),
true,
),
input: deviceBackends[input.d],
output: deviceBackends[output.d],
layers: func() map[int]*C.struct_ggml_backend {
m := make(map[int]*C.struct_ggml_backend)
for i, layer := range layers {
m[i] = deviceBackends[layer.d]
}
return m
}(),
}, nil
}
@ -339,15 +343,21 @@ func (b *Backend) Get(name string) ml.Tensor {
}
// NewContext returns a context sized with the backend's default node budget:
// five nodes per model tensor, with a floor of 8192. It delegates to
// NewContextSize.
func (b *Backend) NewContext() ml.Context {
	return b.NewContextSize(max(8192, len(b.meta.Tensors().Items())*5))
}
// NewContextSize returns a context whose graph is limited to n nodes.
//
// The underlying ggml context is created with no_alloc set and a mem_size that
// covers n tensor headers plus the overhead of a graph with n nodes. New
// tensors default to the scheduler's backend at index 0; the backend's input,
// output and per-layer backends are copied onto the context so that Input,
// Output and Layer can select them.
func (b *Backend) NewContextSize(n int) ml.Context {
	return &Context{
		b: b,
		ctx: C.ggml_init(C.struct_ggml_init_params{
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
			no_alloc: true,
		}),
		backend:       C.ggml_backend_sched_get_backend(b.sched, 0),
		maxGraphNodes: n,
		input:         b.input,
		output:        b.output,
		layers:        b.layers,
	}
}
@ -364,11 +374,61 @@ type Context struct {
ctx *C.struct_ggml_context
graph *C.struct_ggml_cgraph
// backend is the backend used for new tensors
backend *C.struct_ggml_backend
// input is the backend used for inputs
input *C.struct_ggml_backend
// output is the backend used for outputs
output *C.struct_ggml_backend
// layers maps a repeating layer's index to the backend used for that layer
layers map[int]*C.struct_ggml_backend
maxGraphNodes int
}
// Input returns a context appropriate for creating input tensors.
//
// When an input backend is configured, the result is a new Context that shares
// c's owning backend, ggml context and node budget but allocates new tensors
// on the input backend. The input/output/layer backends are carried over so
// that selectors can be chained (for example ctx.Input().Output()) without
// silently falling back to the wrong backend. The graph is not carried over.
// When no input backend is configured, c itself is returned.
func (c *Context) Input() ml.Context {
	if c.input == nil {
		return c
	}

	return &Context{
		b:             c.b,
		ctx:           c.ctx,
		backend:       c.input,
		input:         c.input,
		output:        c.output,
		layers:        c.layers,
		maxGraphNodes: c.maxGraphNodes,
	}
}
// Output returns a context appropriate for creating output tensors.
//
// When an output backend is configured, the result is a new Context that
// shares c's owning backend, ggml context and node budget but allocates new
// tensors on the output backend. The input/output/layer backends are carried
// over so that selectors can be chained without silently falling back to the
// wrong backend. The graph is not carried over. When no output backend is
// configured, c itself is returned.
func (c *Context) Output() ml.Context {
	if c.output == nil {
		return c
	}

	return &Context{
		b:             c.b,
		ctx:           c.ctx,
		backend:       c.output,
		input:         c.input,
		output:        c.output,
		layers:        c.layers,
		maxGraphNodes: c.maxGraphNodes,
	}
}
// Layer returns a context appropriate for creating intermediate tensors that
// belong to layer i.
//
// When a backend is registered for i, the result is a new Context that shares
// c's owning backend, ggml context and node budget but allocates new tensors
// on that layer's backend. The input/output/layer backends are carried over so
// that selectors can be chained without silently falling back to the wrong
// backend. The graph is not carried over. When no backend is registered for i,
// c itself is returned.
func (c *Context) Layer(i int) ml.Context {
	backend, ok := c.layers[i]
	if !ok {
		return c
	}

	return &Context{
		b:             c.b,
		ctx:           c.ctx,
		backend:       backend,
		input:         c.input,
		output:        c.output,
		layers:        c.layers,
		maxGraphNodes: c.maxGraphNodes,
	}
}
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
if c.graph == nil {
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
@ -414,7 +474,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}
func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
if len(shape) < 1 || len(shape) > 4 {
panic("unsupported number of dimensions")
}
@ -428,62 +488,61 @@ func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
var t *C.struct_ggml_tensor
switch dtype {
case ml.DTypeF32:
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeF16:
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
panic("unsupported dtype")
}
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_set_input(t)
return &Tensor{b: ctx.b, t: t}
return &Tensor{b: c.b, t: t}
}
// Empty returns a new tensor of the given dtype and shape, created through
// c.newTensor. This method does not write any values into the tensor.
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	return c.newTensor(dtype, shape)
}
// Zeros returns a new tensor of the given dtype and shape, created through
// c.newTensor, with ggml_set_zero applied to it.
//
// NOTE(review): ggml_set_zero operates on the tensor's data pointer directly;
// confirm this is valid when the context's backend buffer is not
// host-accessible.
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	t := c.newTensor(dtype, shape)
	C.ggml_set_zero(t.(*Tensor).t)
	return t
}
// checkShape reports whether shape describes exactly len(s) elements.
//
// It returns nil when the product of the dimensions equals len(s) and an
// "invalid shape" error otherwise. Specifically:
//   - a negative dimension is always rejected;
//   - a shape containing a zero dimension is valid only for an empty slice;
//   - every non-zero dimension must divide the remaining element count
//     exactly, so a shape whose product is merely close to len(s) (for example
//     7 elements as 2x3) is rejected rather than accepted through truncating
//     division.
//
// The check is done by exact division rather than by multiplying dimensions,
// so it cannot overflow and never divides by zero.
func checkShape[S ~[]E, E any](s S, shape ...int) error {
	remaining := len(s)
	hasZeroDim := false

	for _, dim := range shape {
		switch {
		case dim < 0:
			return fmt.Errorf("invalid shape: %v", shape)
		case dim == 0:
			// Defer the verdict: a zero dimension is fine iff s is empty.
			hasZeroDim = true
		case remaining%dim != 0:
			return fmt.Errorf("invalid shape: %v", shape)
		default:
			remaining /= dim
		}
	}

	if hasZeroDim {
		if len(s) != 0 {
			return fmt.Errorf("invalid shape: %v", shape)
		}
		return nil
	}

	if remaining != 1 {
		return fmt.Errorf("invalid shape: %v", shape)
	}

	return nil
}
// FromFloatSlice creates an F32 tensor of the given shape through c.newTensor
// and copies the values of s into it.
//
// It returns an error when checkShape rejects the combination of s and shape.
// An empty s is tolerated: the copy is skipped because there is no first
// element to take the address of.
//
// NOTE(review): confirm that newTensor accepts shapes containing a zero
// dimension; that function is not fully visible here.
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
	if err := checkShape(s, shape...); err != nil {
		return nil, err
	}

	t := c.newTensor(ml.DTypeF32, shape)
	if len(s) > 0 {
		raw := t.(*Tensor).t
		C.ggml_backend_tensor_set(raw, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(raw))
	}

	return t, nil
}
// FromIntSlice creates an I32 tensor of the given shape through c.newTensor
// and copies the values of s into it.
//
// It returns an error when checkShape rejects the combination of s and shape.
// An empty s is tolerated: the copy is skipped because there is no first
// element to take the address of.
//
// NOTE(review): confirm that newTensor accepts shapes containing a zero
// dimension; that function is not fully visible here.
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
	if err := checkShape(s, shape...); err != nil {
		return nil, err
	}

	t := c.newTensor(ml.DTypeI32, shape)
	if len(s) > 0 {
		raw := t.(*Tensor).t
		C.ggml_backend_tensor_set(raw, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(raw))
	}

	return t, nil
}
func (c Context) Close() {

View File

@ -138,17 +138,17 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

View File

@ -72,7 +72,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err
}
pixelValues, err := ctx.FromFloatSlice(f32s,
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
@ -82,7 +82,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err
}
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
@ -92,7 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
positions[i] = int32(i)
}
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
@ -136,17 +136,17 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}