ggml: Support closing backends

In order to iteratively find the best memory allocation, we need to
be able to free backend memory so we can try again.
This commit is contained in:
Jesse Gross
2025-04-17 17:12:01 -07:00
committed by Jesse Gross
parent d7f4f788d1
commit 756c78cfc7
4 changed files with 71 additions and 24 deletions

View File

@@ -15,6 +15,9 @@ import (
)
type Backend interface {
// Close frees all memory associated with this backend
Close()
Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model

View File

@@ -19,6 +19,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"unicode"
"unsafe"
@@ -33,15 +34,33 @@ import (
"golang.org/x/sync/errgroup"
)
func devices() []C.ggml_backend_dev_t {
ggml.OnceLoad()
ds := make([]C.ggml_backend_dev_t, C.ggml_backend_dev_count())
for i := range ds {
ds[i] = C.ggml_backend_dev_get(C.size_t(i))
}
var (
cpus, accels, gpus []C.ggml_backend_dev_t
backends map[C.ggml_backend_dev_t]C.ggml_backend_t
)
return ds
}
var initDevices = sync.OnceFunc(func() {
ggml.OnceLoad()
backends = make(map[C.ggml_backend_dev_t]C.ggml_backend_t)
for i := range C.ggml_backend_dev_count() {
d := C.ggml_backend_dev_get(i)
switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
cpus = append(cpus, d)
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
gpus = append(gpus, d)
}
backends[d] = C.ggml_backend_dev_init(d, nil)
}
})
type Backend struct {
// modelPath is the location of the model data
@@ -75,6 +94,9 @@ type Backend struct {
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
maxGraphNodes int
// weightBuffers are the GGML contexts and buffers for allocating weights
weightBuffers map[*C.struct_ggml_context]C.ggml_backend_buffer_t
}
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
@@ -99,6 +121,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
"num_key_values", len(meta.KV()),
)
initDevices()
var requiredMemory ml.BackendMemory
btDeviceMemory := make(map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory)
@@ -107,21 +131,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
bts []C.ggml_backend_buffer_type_t
}
var cpus, accels, gpus []C.ggml_backend_dev_t
for _, d := range devices() {
switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
cpus = append(cpus, d)
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
gpus = append(gpus, d)
}
}
blocks := int(meta.KV().BlockCount())
// create list of buffer types for the cpu
@@ -348,6 +357,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
if b == nil {
for _, b := range bbs {
C.ggml_backend_buffer_free(b)
}
for _, ctx := range ctxs {
C.ggml_free(ctx)
}
panic(ml.ErrNoMem{BackendMemory: requiredMemory})
}
@@ -394,7 +411,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
var schedBackends []C.ggml_backend_t
var schedBufts []C.ggml_backend_buffer_type_t
for _, d := range append(gpus, append(accels, cpus...)...) {
b := C.ggml_backend_dev_init(d, nil)
b := backends[d]
bt := C.ggml_backend_get_default_buffer_type(b)
deviceBufferTypes[d] = bt
@@ -436,6 +453,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory: &requiredMemory,
btDeviceMemory: btDeviceMemory,
maxGraphNodes: maxGraphNodes,
weightBuffers: bbs,
}, nil
}
@@ -443,6 +461,19 @@ func init() {
ml.RegisterBackend("ggml", New)
}
func (b *Backend) Close() {
if b == nil {
return
}
for ctx, b := range b.weightBuffers {
C.ggml_backend_buffer_free(b)
C.ggml_free(ctx)
}
C.ggml_backend_sched_free(b.sched)
}
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
var doneBytes atomic.Uint64
totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset