From 73d6a82cce18f84ff5c67148783224cf25b30b32 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Thu, 17 Apr 2025 11:00:25 -0700
Subject: [PATCH] ollamarunner: Memory usage reporting

This provides granular information about the backend memory allocations
required by the runner:
 - Per backend
 - Per layer
 - Weights, cache and graph
 - Allocation status

This can be used for debugging and validating memory estimates.
---
 kvcache/causal_test.go            |   2 +-
 ml/backend.go                     |  84 ++++++++++++++-
 ml/backend/ggml/ggml.go           | 163 ++++++++++++++++++++----------
 runner/ollamarunner/multimodal.go |   5 +-
 runner/ollamarunner/runner.go     |  48 +++++----
 5 files changed, 224 insertions(+), 78 deletions(-)

diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 796987088b..820d496d1d 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
 
 func (c *testContext) Compute(...ml.Tensor) {}
 
-func (c *testContext) Reserve() error { return nil }
+func (c *testContext) Reserve() {}
 
 func (c *testContext) MaxGraphNodes() int {
 	return 10
diff --git a/ml/backend.go b/ml/backend.go
index 3c417ef9de..7c9b9e3139 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -15,6 +15,10 @@ import (
 
 type Backend interface {
 	Load(ctx context.Context, progress func(float32)) error
+
+	// BackendMemory returns the memory allocations that were made for this model
+	BackendMemory() BackendMemory
+
 	Config() fs.Config
 	Get(name string) Tensor
 	NewContext() Context
@@ -68,6 +72,84 @@ type BackendParams struct {
 	FlashAttention bool
 }
 
+// ErrNoMem is returned when panicking due to insufficient memory. It includes
+// the attempted memory allocation.
+type ErrNoMem struct {
+	BackendMemory
+}
+
+func (e ErrNoMem) Error() string {
+	return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
+}
+
+type AllocationStatus int
+
+const (
+	// Unallocated memory - have not yet attempted to allocate
+	Unallocated AllocationStatus = iota
+
+	// Failed memory - tried to allocate the memory and did not succeed
+	Failed
+
+	// Allocated memory - tried and succeeded to allocate memory
+	Allocated
+)
+
+// Memory is the size of an allocation and whether it was successful.
+type Memory struct {
+	Size   uint64
+	Status AllocationStatus
+}
+
+func (m Memory) String() string {
+	s := fmt.Sprint(m.Size)
+
+	switch m.Status {
+	case Unallocated:
+		s += "U"
+	case Failed:
+		s += "F"
+	case Allocated:
+		s += "A"
+	}
+
+	return s
+}
+
+// DeviceMemory provides a breakdown of the memory needed
+// per device, such as a CPU or GPU.
+type DeviceMemory struct {
+	// Name is the name of the device as labeled by the backend. It
+	// may not be persistent across instances of the runner.
+	Name string
+
+	// Weights is the per-layer memory needed for the model weights.
+	Weights []Memory
+
+	// Cache is the per-layer memory needed for the KV cache.
+	Cache []Memory
+
+	// Graph is the size of the compute graph. It is not per-layer.
+	Graph Memory
+}
+
+// BackendMemory provides the amount of memory required to load the model
+// per device based on the BackendParams. In some cases, not all required
+// allocations will be known at this point. However, the size of the most recent
+// allocation is guaranteed to be provided so that if it failed, the caller can
+// accommodate that to make forward progress.
+type BackendMemory struct {
+	// InputWeights are always located on the CPU and cannot be moved
+	InputWeights Memory
+
+	// CPU model components are located in system memory. This does not
+	// include unified memory allocated through the GPU.
+	CPU DeviceMemory
+
+	// GPU model components are located on one or more GPUs.
+	GPUs []DeviceMemory
+}
+
 var backends = make(map[string]func(string, BackendParams) (Backend, error))
 
 func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
@@ -102,7 +184,7 @@ type Context interface {
 	// graph, simply preallocates memory. Typically called with a
 	// worst case graph to ensure all resources are available for
 	// for future inference.
-	Reserve() error
+	Reserve()
 
 	MaxGraphNodes() int
 	Close()
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 44e3b61bc5..496ba8a607 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -10,7 +10,6 @@ import "C"
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -66,6 +65,12 @@ type Backend struct {
 	// layers is the backend used for repeating layers
 	layers map[int]*C.struct_ggml_backend_buffer_type
 
+	// requiredMemory is the cumulative memory allocations needed by the backend
+	requiredMemory *ml.BackendMemory
+
+	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
+	btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory
+
 	flashAttention bool
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
@@ -94,6 +99,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		"num_key_values", len(meta.KV()),
 	)
 
+	var requiredMemory ml.BackendMemory
+	btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory)
+
 	type deviceBufferType struct {
 		d   *C.struct_ggml_backend_device
 		bts []*C.struct_ggml_backend_buffer_type
@@ -114,6 +122,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}
 
+	blocks := int(meta.KV().BlockCount())
+
 	// create list of buffer types for the cpu
 	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
 	for _, d := range append(accels, append(gpus, cpus...)...) {
@@ -121,17 +131,27 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		case C.GGML_BACKEND_DEVICE_TYPE_CPU, C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
 			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
+			btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
 		}
 	}
 
+	requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
+	requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
+	requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
+
 	// create list of buffer types for each gpu
 	var gpuDeviceBufferTypes []deviceBufferType
-	for _, d := range gpus {
+	requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
+	for i, d := range gpus {
 		bt := C.ggml_backend_dev_buffer_type(d)
 		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
 			d:   d,
 			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
 		})
+		btDeviceMemory[bt] = &requiredMemory.GPUs[i]
+		requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
+		requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
+		requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
 	}
 
 	useDefaultSplit := true
@@ -170,8 +190,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	// inputs always use cpu
 	input := cpuDeviceBufferType
 
-	blocks := int(meta.KV().BlockCount())
-
 	// define a range of gpu layers. anything outside of this range is assigned to the cpu
 	gpuRangeStart := max(0, blocks-params.NumGPULayers)
 	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
@@ -212,7 +230,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	// contexts are shared by tensors of the same buffer type
 	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
 
-	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
+	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
 		for _, bt := range bts {
 			if _, ok := ctxs[bt]; !ok {
 				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
@@ -238,6 +256,16 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			C.ggml_set_name(tt, cname)
 
 			slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+
+			size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
+			if layer == -1 {
+				// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
+				requiredMemory.InputWeights.Status = ml.Allocated
+				requiredMemory.InputWeights.Size += uint64(size)
+			} else {
+				btDeviceMemory[bt].Weights[layer].Size += uint64(size)
+			}
+
 			//nolint:staticcheck // TODO: check if buffer type supports this tensor
 			return tt
 		}
@@ -259,22 +287,22 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	for _, t := range meta.Tensors().Items() {
 		switch {
 		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
-			createTensor(tensor{source: t}, input.bts)
+			createTensor(tensor{source: t}, input.bts, -1)
 			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
-				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
+				createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
 			}
 		case contains(t.Name, "cls", "output", "output_norm"):
"output", "output_norm"): - createTensor(tensor{source: t}, output.bts) + createTensor(tensor{source: t}, output.bts, blocks) case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."): // TODO: assign vision tensors to the gpu if possible - createTensor(tensor{source: t}, output.bts) + createTensor(tensor{source: t}, output.bts, blocks) case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"): // these tensors should be repeated per layer for i, layer := range layers { createTensor(tensor{ source: t, target: "blk." + strconv.Itoa(i) + "." + t.Name, - }, layer.bts) + }, layer.bts, i) } default: layerIndex := -1 @@ -285,10 +313,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } if layerIndex >= 0 { - createTensor(tensor{source: t}, layers[layerIndex].bts) + createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex) } else { // load all other tensors on the cpu - createTensor(tensor{source: t}, input.bts) + createTensor(tensor{source: t}, input.bts, -1) } } } @@ -301,8 +329,18 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) + for i := range btDeviceMemory[bt].Weights { + if btDeviceMemory[bt].Weights[i].Size != 0 { + if b != nil { + btDeviceMemory[bt].Weights[i].Status = ml.Allocated + } else { + btDeviceMemory[bt].Weights[i].Status = ml.Failed + } + } + } + if b == nil { - return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt))) + panic(ml.ErrNoMem{BackendMemory: requiredMemory}) } C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) @@ -367,7 +405,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } return m }(), - maxGraphNodes: maxGraphNodes, + requiredMemory: &requiredMemory, + btDeviceMemory: btDeviceMemory, + maxGraphNodes: maxGraphNodes, }, nil } @@ -446,6 +486,10 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { return nil } +func (b *Backend) BackendMemory() ml.BackendMemory { + return *b.requiredMemory +} + func (b *Backend) Config() fs.Config { return b.meta.KV() } @@ -477,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context { no_alloc: true, }), allocatedBuffers: &allocatedBuffers, + layer: -1, } } @@ -503,6 +548,9 @@ type Context struct { // maxGraphNodes is the maximum allowed number of graph nodes in this context maxGraphNodes int + + // layer is the graph layer that this context is allocating for - assumed to be cache + layer int } func (c *Context) Input() ml.Context { @@ -513,6 +561,7 @@ func (c *Context) Input() ml.Context { buft: c.b.input, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, + layer: -1, } } @@ -527,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context { buft: buft, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, + layer: i, } } @@ -564,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } -func (c *Context) Reserve() error { - if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) { - C.ggml_backend_sched_reset(c.b.sched) - return errors.New("failed to reserve graph") - } +func (c *Context) Reserve() { + reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph) slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) - for i := range c.b.schedBackends { - size := C.ggml_backend_sched_get_buffer_size(c.b.sched, 
-		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
-			"size", format.HumanBytes2(uint64(size)))
+
+	// Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
+	for _, bt := range c.b.schedBufts {
+		c.b.btDeviceMemory[bt].Graph = ml.Memory{}
 	}
 
-	C.ggml_backend_sched_reset(c.b.sched)
+	for i := range c.b.schedBackends {
+		bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])
 
-	return nil
+		graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
+		graph.Size += uint64(bufferStatus.size)
+		if bufferStatus.allocated && graph.Status != ml.Failed {
+			graph.Status = ml.Allocated
+		} else {
+			graph.Status = ml.Failed
+		}
+
+		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
+			"size", format.HumanBytes2(uint64(bufferStatus.size)))
+	}
+
+	if !reserved {
+		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+	}
 }
 
 func (c *Context) MaxGraphNodes() int {
@@ -599,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t {
 	return ((length + pad - 1) / pad) * pad
 }
 
-func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
+func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 	if c.buft == nil {
 		panic("set Input or Layer before creating tensors")
 	}
@@ -622,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 
 	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
-		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
+		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
 	} else if len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -635,31 +697,34 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 
 	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
 	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
-	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
-	if b == nil {
-		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
-	}
-	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
 
+	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
+	if c.layer >= 0 {
+		cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]
+
+		cache.Size += uint64(size)
+		if b != nil {
+			cache.Status = ml.Allocated
+		} else {
+			cache.Status = ml.Failed
+		}
+	}
+
+	if b == nil {
+		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+	}
+
+	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	return &Tensor{b: c.b, t: t}, nil
+	return &Tensor{b: c.b, t: t}
 }
 
 func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
-	return t
+	return c.newTensor(dtype, shape)
 }
 
 func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
+	t := c.newTensor(dtype, shape)
 	C.ggml_set_zero(t.(*Tensor).t)
 	return t
 }
@@ -687,10 +752,7 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	t, err := c.newTensor(ml.DTypeF32, shape)
-	if err != nil {
-		return nil, err
-	}
+	t := c.newTensor(ml.DTypeF32, shape)
 
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
@@ -704,10 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	t, err := c.newTensor(ml.DTypeI32, shape)
-	if err != nil {
-		return nil, err
-	}
+	t := c.newTensor(ml.DTypeI32, shape)
 
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go
index d78612feed..dbe6bba106 100644
--- a/runner/ollamarunner/multimodal.go
+++ b/runner/ollamarunner/multimodal.go
@@ -95,10 +95,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 			}
 		}
 	} else {
-		err := computeCtx.Reserve()
-		if err != nil {
-			return nil, err
-		}
+		computeCtx.Reserve()
 	}
 }
 
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index a488a104b5..99bee1061a 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -826,16 +826,12 @@ func (s *Server) reserveWorstCaseGraph() error {
 		return err
 	}
 
-	err = ctx.Forward(t).Reserve()
-	if err != nil {
-		return err
-	}
+	ctx.Forward(t).Reserve()
 
 	return nil
 }
 
-func (s *Server) loadModel(
-	ctx context.Context,
+func (s *Server) initModel(
 	mpath string,
 	params ml.BackendParams,
 	lpath multiLPath,
@@ -843,21 +839,21 @@ func (s *Server) loadModel(
 	kvCacheType string,
 	kvSize int,
 	multiUserCache bool,
-) {
+) error {
 	var err error
 	s.model, err = model.New(mpath, params)
 	if err != nil {
-		panic(err)
+		return err
 	}
 
 	// TODO(jessegross): LoRA loading
 	if lpath.String() != "" {
-		panic("loras are not yet implemented")
+		return errors.New("loras are not yet implemented")
 	}
 
 	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
-		panic(err)
+		return err
 	}
 
 	if !s.cache.enabled && parallel > 1 {
@@ -869,11 +865,25 @@ func (s *Server) loadModel(
 
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	err = s.reserveWorstCaseGraph()
+	return s.reserveWorstCaseGraph()
+}
+
+func (s *Server) load(
+	ctx context.Context,
+	mpath string,
+	params ml.BackendParams,
+	lpath multiLPath,
+	parallel int,
+	kvCacheType string,
+	kvSize int,
+	multiUserCache bool) {
+	err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
 	if err != nil {
 		panic(err)
 	}
 
+	slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
+
 	err = s.model.Backend().Load(ctx, func(progress float32) {
 		s.progress = progress
@@ -921,9 +931,14 @@ func Execute(args []string) error {
 		status:    llm.ServerStatusLoadingModel,
 	}
 
+	server.cond = sync.NewCond(&server.mu)
+	server.ready.Add(1)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	// TODO(jessegross): Parameters that need to be implemented:
 	//  no-mmap
-	//  mlock
 
 	var tensorSplitFloats []float32
 	if *tensorSplit != "" {
@@ -943,14 +958,7 @@ func Execute(args []string) error {
 		FlashAttention: *flashAttention,
 	}
 
-	server.ready.Add(1)
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-	server.cond = sync.NewCond(&server.mu)
-
+	go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
 	go server.run(ctx)
 
 	addr := "127.0.0.1:" + strconv.Itoa(*port)
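
For reference, a minimal sketch (not part of the patch) of how the per-device breakdown exposed by the new BackendMemory() accessor might be folded into totals when validating memory estimates. The ml types are mirrored locally so the snippet compiles on its own, and the device name and sizes used in main are hypothetical.

package main

import "fmt"

// Local mirrors of the ml.Memory and ml.DeviceMemory types added by this patch,
// kept minimal so the sketch is self-contained.
type AllocationStatus int

const (
	Unallocated AllocationStatus = iota
	Failed
	Allocated
)

type Memory struct {
	Size   uint64
	Status AllocationStatus
}

type DeviceMemory struct {
	Name    string
	Weights []Memory // per layer
	Cache   []Memory // per layer
	Graph   Memory   // whole compute graph, not per layer
}

// summarize folds the per-layer entries of one device into totals and reports
// whether any attempted allocation on that device failed.
func summarize(d DeviceMemory) (weights, cache, graph uint64, ok bool) {
	ok = true
	add := func(m Memory) uint64 {
		if m.Status == Failed {
			ok = false
		}
		return m.Size
	}
	for _, m := range d.Weights {
		weights += add(m)
	}
	for _, m := range d.Cache {
		cache += add(m)
	}
	graph = add(d.Graph)
	return weights, cache, graph, ok
}

func main() {
	// Hypothetical numbers for a two-layer model on a single GPU.
	gpu := DeviceMemory{
		Name:    "ExampleGPU",
		Weights: []Memory{{Size: 512 << 20, Status: Allocated}, {Size: 512 << 20, Status: Allocated}},
		Cache:   []Memory{{Size: 64 << 20, Status: Allocated}, {Size: 64 << 20, Status: Failed}},
		Graph:   Memory{Size: 128 << 20, Status: Allocated},
	}

	w, c, g, ok := summarize(gpu)
	fmt.Printf("%s: weights=%d cache=%d graph=%d allocated=%v\n", gpu.Name, w, c, g, ok)
}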