ollamarunner: Memory usage reporting

This provides granular information about the backend memory allocations
required by the runner:
 - Per backend
 - Per layer
 - Weights, cache and graph
 - Allocation status

This can be used for debugging and validating memory estimates.
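
As a rough illustration only (not part of this commit), the reported structure could be consumed like the sketch below, using the ml.BackendMemory types added in this change. The summarize helper, its package name, and the github.com/ollama/ollama/ml import path are assumptions:

package memdebug

import (
	"fmt"

	"github.com/ollama/ollama/ml" // assumed import path
)

// summarize totals the per-layer slices reported by Backend.BackendMemory()
// and prints one line per device. Memory values print with their allocation
// status suffix (U/F/A) via Memory.String.
func summarize(mem ml.BackendMemory) {
	fmt.Println("input weights:", mem.InputWeights)

	for _, dev := range append([]ml.DeviceMemory{mem.CPU}, mem.GPUs...) {
		var weights, cache uint64
		failed := false

		for _, w := range dev.Weights {
			weights += w.Size
			failed = failed || w.Status == ml.Failed
		}
		for _, c := range dev.Cache {
			cache += c.Size
			failed = failed || c.Status == ml.Failed
		}

		fmt.Printf("%s: weights=%d cache=%d graph=%v failed=%v\n",
			dev.Name, weights, cache, dev.Graph, failed)
	}
}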
Jesse Gross
2025-04-17 11:00:25 -07:00
committed by Jesse Gross
parent 6db8a3771c
commit 73d6a82cce
5 changed files with 224 additions and 78 deletions

View File

@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
 func (c *testContext) Compute(...ml.Tensor) {}
-func (c *testContext) Reserve() error { return nil }
+func (c *testContext) Reserve() {}
 func (c *testContext) MaxGraphNodes() int {
 	return 10

View File

@@ -15,6 +15,10 @@ import (
 type Backend interface {
 	Load(ctx context.Context, progress func(float32)) error
+
+	// BackendMemory returns the memory allocations that were made for this model
+	BackendMemory() BackendMemory
+
 	Config() fs.Config
 	Get(name string) Tensor
 	NewContext() Context
@@ -68,6 +72,84 @@ type BackendParams struct {
 	FlashAttention bool
 }
+
+// ErrNoMem is returned when panicking due to insufficient memory. It includes
+// the attempted memory allocation.
+type ErrNoMem struct {
+	BackendMemory
+}
+
+func (e ErrNoMem) Error() string {
+	return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
+}
+
+type AllocationStatus int
+
+const (
+	// Unallocated memory - have not yet attempted to allocate
+	Unallocated AllocationStatus = iota
+
+	// Failed memory - tried to allocate the memory and did not succeed
+	Failed
+
+	// Allocated memory - tried and succeeded to allocate memory
+	Allocated
+)
+
+// Memory is the size of an allocation and whether it was successful.
+type Memory struct {
+	Size   uint64
+	Status AllocationStatus
+}
+
+func (m Memory) String() string {
+	s := fmt.Sprint(m.Size)
+
+	switch m.Status {
+	case Unallocated:
+		s += "U"
+	case Failed:
+		s += "F"
+	case Allocated:
+		s += "A"
+	}
+
+	return s
+}
+
+// DeviceMemory provides a breakdown of the memory needed
+// per device, such as a CPU or GPU.
+type DeviceMemory struct {
+	// Name is the name of the device as labeled by the backend. It
+	// may not be persistent across instances of the runner.
+	Name string
+
+	// Weights is the per-layer memory needed for the model weights.
+	Weights []Memory
+
+	// Cache is the per-layer memory needed for the KV cache.
+	Cache []Memory
+
+	// Graph is the size of the compute graph. It is not per-layer.
+	Graph Memory
+}
+
+// BackendMemory provides the amount of memory required to load the model
+// per device based on the BackendParams. In some cases, not all required
+// allocations will be known at this point. However, the size of the most recent
+// allocation is guaranteed to be provided so that if it failed, the caller can
+// accommodate that to make forward progress.
+type BackendMemory struct {
+	// InputWeights are always located on the CPU and cannot be moved
+	InputWeights Memory
+
+	// CPU model components are located in system memory. This does not
+	// include unified memory allocated through the GPU.
+	CPU DeviceMemory
+
+	// GPU model components are located on one or more GPUs.
+	GPUs []DeviceMemory
+}
+
 var backends = make(map[string]func(string, BackendParams) (Backend, error))
 
 func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
@@ -102,7 +184,7 @@ type Context interface {
 	// graph, simply preallocates memory. Typically called with a
 	// worst case graph to ensure all resources are available for
 	// for future inference.
-	Reserve() error
+	Reserve()
 	MaxGraphNodes() int
 	Close()
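
Allocation failures in the backend now surface as a panic carrying ml.ErrNoMem rather than a returned error (see the ggml and runner changes below). A minimal sketch of how a caller could convert that panic back into an ordinary error and inspect the reported memory; the wrapper, its package name, and the import path are assumptions rather than part of this commit:

package memdebug

import "github.com/ollama/ollama/ml" // assumed import path

// withNoMemRecovery runs construct and converts the ml.ErrNoMem panic raised
// on allocation failure into an ordinary error, so the caller can inspect the
// reported BackendMemory (including the allocation that failed) and retry
// with a different layer split.
func withNoMemRecovery(construct func() (ml.Backend, error)) (b ml.Backend, err error) {
	defer func() {
		if r := recover(); r != nil {
			noMem, ok := r.(ml.ErrNoMem)
			if !ok {
				panic(r) // not a memory failure; re-raise
			}
			err = noMem // ErrNoMem implements error and embeds the full BackendMemory
		}
	}()

	return construct()
}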

View File

@@ -10,7 +10,6 @@ import "C"
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -66,6 +65,12 @@ type Backend struct {
 	// layers is the backend used for repeating layers
 	layers map[int]*C.struct_ggml_backend_buffer_type
 
+	// requiredMemory is the cumulative memory allocations needed by the backend
+	requiredMemory *ml.BackendMemory
+
+	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
+	btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory
+
 	flashAttention bool
 
 	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
@@ -94,6 +99,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		"num_key_values", len(meta.KV()),
 	)
 
+	var requiredMemory ml.BackendMemory
+	btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory)
+
 	type deviceBufferType struct {
 		d   *C.struct_ggml_backend_device
 		bts []*C.struct_ggml_backend_buffer_type
@@ -114,6 +122,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}
 
+	blocks := int(meta.KV().BlockCount())
+
 	// create list of buffer types for the cpu
 	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
 	for _, d := range append(accels, append(gpus, cpus...)...) {
@@ -121,17 +131,27 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
 			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
 			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
+			btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
 		}
 	}
 
+	requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
+	requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
+	requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
+
 	// create list of buffer types for each gpu
 	var gpuDeviceBufferTypes []deviceBufferType
-	for _, d := range gpus {
+	requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
+	for i, d := range gpus {
 		bt := C.ggml_backend_dev_buffer_type(d)
 		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
 			d:   d,
 			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
 		})
+		btDeviceMemory[bt] = &requiredMemory.GPUs[i]
+		requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
+		requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
+		requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
 	}
 
 	useDefaultSplit := true
@@ -170,8 +190,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	// inputs always use cpu
 	input := cpuDeviceBufferType
 
-	blocks := int(meta.KV().BlockCount())
-
 	// define a range of gpu layers. anything outside of this range is assigned to the cpu
 	gpuRangeStart := max(0, blocks-params.NumGPULayers)
 	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
@@ -212,7 +230,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	// contexts are shared by tensors of the same buffer type
 	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
 
-	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
+	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
 		for _, bt := range bts {
 			if _, ok := ctxs[bt]; !ok {
 				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
@@ -238,6 +256,16 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			C.ggml_set_name(tt, cname)
 			slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
 
+			size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
+			if layer == -1 {
+				// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
+				requiredMemory.InputWeights.Status = ml.Allocated
+				requiredMemory.InputWeights.Size += uint64(size)
+			} else {
+				btDeviceMemory[bt].Weights[layer].Size += uint64(size)
+			}
+
 			//nolint:staticcheck // TODO: check if buffer type supports this tensor
 			return tt
 		}
@@ -259,22 +287,22 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 	for _, t := range meta.Tensors().Items() {
 		switch {
 		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
-			createTensor(tensor{source: t}, input.bts)
+			createTensor(tensor{source: t}, input.bts, -1)
 			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
-				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
+				createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
 			}
 		case contains(t.Name, "cls", "output", "output_norm"):
-			createTensor(tensor{source: t}, output.bts)
+			createTensor(tensor{source: t}, output.bts, blocks)
 		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
 			// TODO: assign vision tensors to the gpu if possible
-			createTensor(tensor{source: t}, output.bts)
+			createTensor(tensor{source: t}, output.bts, blocks)
 		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
 			// these tensors should be repeated per layer
 			for i, layer := range layers {
 				createTensor(tensor{
 					source: t,
 					target: "blk." + strconv.Itoa(i) + "." + t.Name,
-				}, layer.bts)
+				}, layer.bts, i)
 			}
 		default:
 			layerIndex := -1
@@ -285,10 +313,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			}
 
 			if layerIndex >= 0 {
-				createTensor(tensor{source: t}, layers[layerIndex].bts)
+				createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex)
 			} else {
 				// load all other tensors on the cpu
-				createTensor(tensor{source: t}, input.bts)
+				createTensor(tensor{source: t}, input.bts, -1)
 			}
 		}
 	}
@@ -301,8 +329,18 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 		}
 
 		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
+		for i := range btDeviceMemory[bt].Weights {
+			if btDeviceMemory[bt].Weights[i].Size != 0 {
+				if b != nil {
+					btDeviceMemory[bt].Weights[i].Status = ml.Allocated
+				} else {
+					btDeviceMemory[bt].Weights[i].Status = ml.Failed
+				}
+			}
+		}
+
 		if b == nil {
-			return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
+			panic(ml.ErrNoMem{BackendMemory: requiredMemory})
 		}
 
 		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
@@ -367,7 +405,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			}
 
 			return m
 		}(),
-		maxGraphNodes: maxGraphNodes,
+		requiredMemory: &requiredMemory,
+		btDeviceMemory: btDeviceMemory,
+		maxGraphNodes:  maxGraphNodes,
 	}, nil
 }
@@ -446,6 +486,10 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
 	return nil
 }
 
+func (b *Backend) BackendMemory() ml.BackendMemory {
+	return *b.requiredMemory
+}
+
 func (b *Backend) Config() fs.Config {
 	return b.meta.KV()
 }
@@ -477,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 			no_alloc: true,
 		}),
 		allocatedBuffers: &allocatedBuffers,
+		layer:            -1,
 	}
 }
@@ -503,6 +548,9 @@ type Context struct {
 	// maxGraphNodes is the maximum allowed number of graph nodes in this context
 	maxGraphNodes int
+
+	// layer is the graph layer that this context is allocating for - assumed to be cache
+	layer int
 }
 
 func (c *Context) Input() ml.Context {
@@ -513,6 +561,7 @@ func (c *Context) Input() ml.Context {
 		buft:             c.b.input,
 		allocatedBuffers: c.allocatedBuffers,
 		maxGraphNodes:    c.maxGraphNodes,
+		layer:            -1,
 	}
 }
@@ -527,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context {
 		buft:             buft,
 		allocatedBuffers: c.allocatedBuffers,
 		maxGraphNodes:    c.maxGraphNodes,
+		layer:            i,
 	}
 }
@@ -564,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
 	}
 }
 
-func (c *Context) Reserve() error {
-	if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
-		C.ggml_backend_sched_reset(c.b.sched)
-		return errors.New("failed to reserve graph")
-	}
+func (c *Context) Reserve() {
+	reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
 
 	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
 
-	for i := range c.b.schedBackends {
-		size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
-		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
-			"size", format.HumanBytes2(uint64(size)))
+	// Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
+	for _, bt := range c.b.schedBufts {
+		c.b.btDeviceMemory[bt].Graph = ml.Memory{}
 	}
 
-	C.ggml_backend_sched_reset(c.b.sched)
+	for i := range c.b.schedBackends {
+		bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])
 
-	return nil
+		graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
+		graph.Size += uint64(bufferStatus.size)
+		if bufferStatus.allocated && graph.Status != ml.Failed {
+			graph.Status = ml.Allocated
+		} else {
+			graph.Status = ml.Failed
+		}
+
+		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
+			"size", format.HumanBytes2(uint64(bufferStatus.size)))
+	}
+
+	if !reserved {
+		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+	}
 }
 
 func (c *Context) MaxGraphNodes() int {
@@ -599,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t {
 	return ((length + pad - 1) / pad) * pad
 }
 
-func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
+func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 	if c.buft == nil {
 		panic("set Input or Layer before creating tensors")
 	}
@@ -622,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
-		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
+		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
 	} else if len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -635,31 +697,34 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
 	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
 
-	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
-	if b == nil {
-		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
-	}
-	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
+	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
+
+	if c.layer >= 0 {
+		cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]
+
+		cache.Size += uint64(size)
+		if b != nil {
+			cache.Status = ml.Allocated
+		} else {
+			cache.Status = ml.Failed
+		}
+	}
+
+	if b == nil {
+		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
+	}
+
+	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
 
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	return &Tensor{b: c.b, t: t}, nil
+	return &Tensor{b: c.b, t: t}
 }
 
 func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
-	return t
+	return c.newTensor(dtype, shape)
 }
 
 func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
+	t := c.newTensor(dtype, shape)
 	C.ggml_set_zero(t.(*Tensor).t)
 
 	return t
 }
@@ -687,10 +752,7 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	t, err := c.newTensor(ml.DTypeF32, shape)
-	if err != nil {
-		return nil, err
-	}
+	t := c.newTensor(ml.DTypeF32, shape)
 
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
@@ -704,10 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	t, err := c.newTensor(ml.DTypeI32, shape)
-	if err != nil {
-		return nil, err
-	}
+	t := c.newTensor(ml.DTypeI32, shape)
 
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))

View File

@@ -95,10 +95,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 			}
 		}
 	} else {
-		err := computeCtx.Reserve()
-		if err != nil {
-			return nil, err
-		}
+		computeCtx.Reserve()
 	}
 }

View File

@@ -826,16 +826,12 @@ func (s *Server) reserveWorstCaseGraph() error {
 		return err
 	}
 
-	err = ctx.Forward(t).Reserve()
-	if err != nil {
-		return err
-	}
+	ctx.Forward(t).Reserve()
 
 	return nil
 }
 
-func (s *Server) loadModel(
-	ctx context.Context,
+func (s *Server) initModel(
 	mpath string,
 	params ml.BackendParams,
 	lpath multiLPath,
@@ -843,21 +839,21 @@ func (s *Server) loadModel(
 	kvCacheType string,
 	kvSize int,
 	multiUserCache bool,
-) {
+) error {
 	var err error
 	s.model, err = model.New(mpath, params)
 	if err != nil {
-		panic(err)
+		return err
 	}
 
 	// TODO(jessegross): LoRA loading
 	if lpath.String() != "" {
-		panic("loras are not yet implemented")
+		return errors.New("loras are not yet implemented")
 	}
 
 	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
-		panic(err)
+		return err
 	}
 
 	if !s.cache.enabled && parallel > 1 {
@@ -869,11 +865,25 @@ func (s *Server) loadModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	err = s.reserveWorstCaseGraph()
+	return s.reserveWorstCaseGraph()
+}
+
+func (s *Server) load(
+	ctx context.Context,
+	mpath string,
+	params ml.BackendParams,
+	lpath multiLPath,
+	parallel int,
+	kvCacheType string,
+	kvSize int,
+	multiUserCache bool) {
+	err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
 	if err != nil {
 		panic(err)
 	}
 
+	slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
+
 	err = s.model.Backend().Load(ctx,
 		func(progress float32) {
 			s.progress = progress
@@ -921,9 +931,14 @@ func Execute(args []string) error {
 		status: llm.ServerStatusLoadingModel,
 	}
 
+	server.cond = sync.NewCond(&server.mu)
+	server.ready.Add(1)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	// TODO(jessegross): Parameters that need to be implemented:
 	// no-mmap
-	// mlock
 
 	var tensorSplitFloats []float32
 	if *tensorSplit != "" {
@@ -943,14 +958,7 @@ func Execute(args []string) error {
 		FlashAttention: *flashAttention,
 	}
 
-	server.ready.Add(1)
-
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-	server.cond = sync.NewCond(&server.mu)
+	go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
 
 	go server.run(ctx)
 	addr := "127.0.0.1:" + strconv.Itoa(*port)