ollamarunner: Memory usage reporting
This provides granular information about the backend memory allocations required by the runner:

- Per backend
- Per layer
- Weights, cache and graph
- Allocation status

This can be used for debugging and validating memory estimates.
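As an illustration of how a caller might consume the new report, here is a minimal sketch. The helper name, the constructed values, and the import path are assumptions for the sketch, not part of this change.

package main

import (
	"fmt"

	"github.com/ollama/ollama/ml"
)

// logMemory prints the per-device, per-layer breakdown added by this change.
// Each Memory value renders as "<size><status>" (for example "1048576A" for an
// allocated megabyte) via Memory.String.
func logMemory(mem ml.BackendMemory) {
	fmt.Println("input weights (CPU):", mem.InputWeights)
	for _, dev := range append([]ml.DeviceMemory{mem.CPU}, mem.GPUs...) {
		fmt.Printf("%s: weights=%v cache=%v graph=%v\n", dev.Name, dev.Weights, dev.Cache, dev.Graph)
	}
}

func main() {
	// Stand-in values; in the runner this would come from Backend.BackendMemory().
	mem := ml.BackendMemory{
		InputWeights: ml.Memory{Size: 512, Status: ml.Allocated},
		CPU: ml.DeviceMemory{
			Name:    "CPU",
			Weights: []ml.Memory{{Size: 1 << 20, Status: ml.Allocated}},
			Cache:   []ml.Memory{{Size: 1 << 16, Status: ml.Unallocated}},
		},
	}
	logMemory(mem)
}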
@@ -15,6 +15,10 @@ import (
 type Backend interface {
 	Load(ctx context.Context, progress func(float32)) error
+
+	// BackendMemory returns the memory allocations that were made for this model
+	BackendMemory() BackendMemory
+
 	Config() fs.Config
 	Get(name string) Tensor
 	NewContext() Context
@@ -68,6 +72,84 @@ type BackendParams struct {
 	FlashAttention bool
 }
 
+// ErrNoMem is returned when panicking due to insufficient memory. It includes
+// the attempted memory allocation.
+type ErrNoMem struct {
+	BackendMemory
+}
+
+func (e ErrNoMem) Error() string {
+	return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
+}
+
+type AllocationStatus int
+
+const (
+	// Unallocated memory - have not yet attempted to allocate
+	Unallocated AllocationStatus = iota
+
+	// Failed memory - tried to allocate the memory and did not succeed
+	Failed
+
+	// Allocated memory - tried and succeeded to allocate memory
+	Allocated
+)
+
+// Memory is the size of an allocation and whether it was successful.
+type Memory struct {
+	Size   uint64
+	Status AllocationStatus
+}
+
+func (m Memory) String() string {
+	s := fmt.Sprint(m.Size)
+
+	switch m.Status {
+	case Unallocated:
+		s += "U"
+	case Failed:
+		s += "F"
+	case Allocated:
+		s += "A"
+	}
+
+	return s
+}
+
+// DeviceMemory provides a breakdown of the memory needed
+// per device, such as a CPU or GPU.
+type DeviceMemory struct {
+	// Name is the name of the device as labeled by the backend. It
+	// may not be persistent across instances of the runner.
+	Name string
+
+	// Weights is the per-layer memory needed for the model weights.
+	Weights []Memory
+
+	// Cache is the per-layer memory needed for the KV cache.
+	Cache []Memory
+
+	// Graph is the size of the compute graph. It is not per-layer.
+	Graph Memory
+}
+
+// BackendMemory provides the amount of memory required to load the model
+// per device based on the BackendParams. In some cases, not all required
+// allocations will be known at this point. However, the size of the most recent
+// allocation is guaranteed to be provided so that if it failed, the caller can
+// accommodate that to make forward progress.
+type BackendMemory struct {
+	// InputWeights are always located on the CPU and cannot be moved
+	InputWeights Memory
+
+	// CPU model components are located in system memory. This does not
+	// include unified memory allocated through the GPU.
+	CPU DeviceMemory
+
+	// GPU model components are located on one or more GPUs.
+	GPUs []DeviceMemory
+}
+
 var backends = make(map[string]func(string, BackendParams) (Backend, error))
 
 func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
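The compact Memory format above is what shows up in the per-layer slices when the report is logged. A tiny illustration of the String output (values made up; the import path is an assumption):

package main

import (
	"fmt"

	"github.com/ollama/ollama/ml"
)

func main() {
	// Each Memory prints as its size followed by a one-letter status:
	// U = unallocated, F = failed, A = allocated.
	fmt.Println(ml.Memory{Size: 1 << 20, Status: ml.Allocated}) // 1048576A
	fmt.Println(ml.Memory{Size: 4096, Status: ml.Failed})       // 4096F
	fmt.Println(ml.Memory{})                                    // 0U
}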
@@ -102,7 +184,7 @@ type Context interface {
 	// graph, simply preallocates memory. Typically called with a
 	// worst case graph to ensure all resources are available
 	// for future inference.
-	Reserve() error
+	Reserve()
 
 	MaxGraphNodes() int
 	Close()
@@ -10,7 +10,6 @@ import "C"
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -66,6 +65,12 @@ type Backend struct {
 	// layers is the backend used for repeating layers
 	layers map[int]*C.struct_ggml_backend_buffer_type
 
+	// requiredMemory is the cumulative memory allocations needed by the backend
+	requiredMemory *ml.BackendMemory
+
+	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
+	btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory
+
 	flashAttention bool
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
@@ -94,6 +99,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		"num_key_values", len(meta.KV()),
 	)
 
+	var requiredMemory ml.BackendMemory
+	btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory)
+
 	type deviceBufferType struct {
 		d   *C.struct_ggml_backend_device
 		bts []*C.struct_ggml_backend_buffer_type
@@ -114,6 +122,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}
 
+	blocks := int(meta.KV().BlockCount())
+
 	// create list of buffer types for the cpu
 	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
 	for _, d := range append(accels, append(gpus, cpus...)...) {
@@ -121,17 +131,27 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
 			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
 			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
+			btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
 		}
 	}
 
+	requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
+	requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
+	requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
+
 	// create list of buffer types for each gpu
 	var gpuDeviceBufferTypes []deviceBufferType
-	for _, d := range gpus {
+	requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
+	for i, d := range gpus {
 		bt := C.ggml_backend_dev_buffer_type(d)
 		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
 			d:   d,
 			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
 		})
+		btDeviceMemory[bt] = &requiredMemory.GPUs[i]
+		requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
+		requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
+		requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
 	}
 
 	useDefaultSplit := true
@@ -170,8 +190,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	// inputs always use cpu
 	input := cpuDeviceBufferType
 
-	blocks := int(meta.KV().BlockCount())
-
 	// define a range of gpu layers. anything outside of this range is assigned to the cpu
 	gpuRangeStart := max(0, blocks-params.NumGPULayers)
 	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
@@ -212,7 +230,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 
 	// contexts are shared by tensors of the same buffer type
 	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
-	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
+	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
 		for _, bt := range bts {
 			if _, ok := ctxs[bt]; !ok {
 				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
@@ -238,6 +256,16 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			C.ggml_set_name(tt, cname)
 
 			slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+
+			size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
+			if layer == -1 {
+				// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
+				requiredMemory.InputWeights.Status = ml.Allocated
+				requiredMemory.InputWeights.Size += uint64(size)
+			} else {
+				btDeviceMemory[bt].Weights[layer].Size += uint64(size)
+			}
+
 			//nolint:staticcheck // TODO: check if buffer type supports this tensor
 			return tt
 		}
@@ -259,22 +287,22 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	for _, t := range meta.Tensors().Items() {
 		switch {
 		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
-			createTensor(tensor{source: t}, input.bts)
+			createTensor(tensor{source: t}, input.bts, -1)
 			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
-				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
+				createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
 			}
 		case contains(t.Name, "cls", "output", "output_norm"):
-			createTensor(tensor{source: t}, output.bts)
+			createTensor(tensor{source: t}, output.bts, blocks)
 		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
 			// TODO: assign vision tensors to the gpu if possible
-			createTensor(tensor{source: t}, output.bts)
+			createTensor(tensor{source: t}, output.bts, blocks)
 		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
 			// these tensors should be repeated per layer
 			for i, layer := range layers {
 				createTensor(tensor{
 					source: t,
 					target: "blk." + strconv.Itoa(i) + "." + t.Name,
-				}, layer.bts)
+				}, layer.bts, i)
 			}
 		default:
 			layerIndex := -1
@@ -285,10 +313,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			}
 
 			if layerIndex >= 0 {
-				createTensor(tensor{source: t}, layers[layerIndex].bts)
+				createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex)
 			} else {
 				// load all other tensors on the cpu
-				createTensor(tensor{source: t}, input.bts)
+				createTensor(tensor{source: t}, input.bts, -1)
 			}
 		}
 	}
@@ -301,8 +329,18 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		}
 
 		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
+		for i := range btDeviceMemory[bt].Weights {
+			if btDeviceMemory[bt].Weights[i].Size != 0 {
+				if b != nil {
+					btDeviceMemory[bt].Weights[i].Status = ml.Allocated
+				} else {
+					btDeviceMemory[bt].Weights[i].Status = ml.Failed
+				}
+			}
+		}
+
 		if b == nil {
-			return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
+			panic(ml.ErrNoMem{BackendMemory: requiredMemory})
 		}
 
 		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
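With this change the ggml backend panics with ml.ErrNoMem instead of returning an error when a weight buffer cannot be allocated, so the caller is expected to recover it at a boundary. A minimal sketch of that boundary follows; the wrapper function, the simulated failure, and the import path are illustrative, not part of this commit.

package main

import (
	"fmt"

	"github.com/ollama/ollama/ml"
)

// withMemoryRecovery converts an ml.ErrNoMem panic raised during model setup
// back into an ordinary error so the caller can inspect the reported
// allocations (failed entries carry the "F" status) and retry with a
// different layer split.
func withMemoryRecovery(f func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if noMem, ok := r.(ml.ErrNoMem); ok {
				err = fmt.Errorf("insufficient memory: %w", noMem)
				return
			}
			panic(r) // not a memory failure; re-raise
		}
	}()
	return f()
}

func main() {
	err := withMemoryRecovery(func() error {
		// In the runner this would construct the backend and reserve a
		// worst-case graph; here we just simulate an allocation failure.
		panic(ml.ErrNoMem{BackendMemory: ml.BackendMemory{}})
	})
	fmt.Println(err)
}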
@@ -367,7 +405,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			}
 			return m
 		}(),
-		maxGraphNodes: maxGraphNodes,
+		requiredMemory: &requiredMemory,
+		btDeviceMemory: btDeviceMemory,
+		maxGraphNodes:  maxGraphNodes,
 	}, nil
 }
@@ -446,6 +486,10 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
 	return nil
 }
 
+func (b *Backend) BackendMemory() ml.BackendMemory {
+	return *b.requiredMemory
+}
+
 func (b *Backend) Config() fs.Config {
 	return b.meta.KV()
 }
@@ -477,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 			no_alloc: true,
 		}),
 		allocatedBuffers: &allocatedBuffers,
+		layer:            -1,
 	}
 }
@@ -503,6 +548,9 @@ type Context struct {
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this context
 	maxGraphNodes int
+
+	// layer is the graph layer that this context is allocating for - assumed to be cache
+	layer int
 }
 
 func (c *Context) Input() ml.Context {
@@ -513,6 +561,7 @@ func (c *Context) Input() ml.Context {
 			buft:             c.b.input,
 			allocatedBuffers: c.allocatedBuffers,
 			maxGraphNodes:    c.maxGraphNodes,
+			layer:            -1,
 		}
 	}
@@ -527,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context {
 			buft:             buft,
 			allocatedBuffers: c.allocatedBuffers,
 			maxGraphNodes:    c.maxGraphNodes,
+			layer:            i,
 		}
 	}
@@ -564,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
 	}
 }
 
-func (c *Context) Reserve() error {
-	if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
-		C.ggml_backend_sched_reset(c.b.sched)
-		return errors.New("failed to reserve graph")
-	}
+func (c *Context) Reserve() {
+	reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
 
 	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
-	for i := range c.b.schedBackends {
-		size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
-		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
-			"size", format.HumanBytes2(uint64(size)))
+
+	// Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
+	for _, bt := range c.b.schedBufts {
+		c.b.btDeviceMemory[bt].Graph = ml.Memory{}
 	}
 
-	C.ggml_backend_sched_reset(c.b.sched)
+	for i := range c.b.schedBackends {
+		bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])
 
-	return nil
+		graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
+		graph.Size += uint64(bufferStatus.size)
+		if bufferStatus.allocated && graph.Status != ml.Failed {
+			graph.Status = ml.Allocated
+		} else {
+			graph.Status = ml.Failed
+		}
+
+		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
+			"size", format.HumanBytes2(uint64(bufferStatus.size)))
+	}
+
+	C.ggml_backend_sched_reset(c.b.sched)
+
+	if !reserved {
+		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+	}
 }
@@ -599,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t {
 	return ((length + pad - 1) / pad) * pad
 }
 
-func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
+func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 	if c.buft == nil {
 		panic("set Input or Layer before creating tensors")
 	}
@@ -622,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 
 	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
-		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
+		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
 	} else if len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -635,31 +697,34 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 
 	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
 	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
-	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
-	if b == nil {
-		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
-	}
-	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
 
+	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
+	if c.layer >= 0 {
+		cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]
+
+		cache.Size += uint64(size)
+		if b != nil {
+			cache.Status = ml.Allocated
+		} else {
+			cache.Status = ml.Failed
+		}
+	}
+
+	if b == nil {
+		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+	}
+
+	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	return &Tensor{b: c.b, t: t}, nil
+	return &Tensor{b: c.b, t: t}
 }
 
 func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
-	return t
+	return c.newTensor(dtype, shape)
 }
 
 func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
+	t := c.newTensor(dtype, shape)
 	C.ggml_set_zero(t.(*Tensor).t)
 	return t
 }
@@ -687,10 +752,7 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	t, err := c.newTensor(ml.DTypeF32, shape)
-	if err != nil {
-		return nil, err
-	}
+	t := c.newTensor(ml.DTypeF32, shape)
 
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
@@ -704,10 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	t, err := c.newTensor(ml.DTypeI32, shape)
-	if err != nil {
-		return nil, err
-	}
+	t := c.newTensor(ml.DTypeI32, shape)
 
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))