Merge pull request #9897 from ollama/mxyng/chunk-load

ml/backend/ggml: load tensors in 128KiB chunks
This commit is contained in:
Michael Yang 2025-03-21 14:47:13 -07:00 committed by GitHub
commit 4b34930a31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 31 deletions

View File

@ -2,6 +2,7 @@ package ml
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
@ -80,9 +81,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@ -90,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
backends[name] = f
}
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(f, params)
return backend(ctx, f, params)
}
return nil, fmt.Errorf("unsupported backend")

View File

@ -9,15 +9,17 @@ package ggml
import "C"
import (
"errors"
"context"
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strconv"
"strings"
"sync/atomic"
"unicode"
"unsafe"
@ -58,7 +60,7 @@ type Backend struct {
maxGraphNodes int
}
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
}
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() {
for _, target := range targets[t.Name] {
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" {
target = t.Name
}
@ -312,23 +318,42 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
bts := C.malloc(C.size_t(t.Size()))
if bts == nil {
return errors.New("failed to allocate tensor buffer")
}
defer C.free(bts)
buf := unsafe.Slice((*byte)(bts), t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
if err != nil || n != len(buf) {
return errors.New("read failed")
tts[i] = tt
}
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
return nil
})
}
}
// start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil {
return nil, err

View File

@ -1,6 +1,7 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(r, params)
b, err := ml.NewBackend(ctx, r, params)
if err != nil {
return nil, err
}

View File

@ -678,6 +678,7 @@ func (m *multiLPath) String() string {
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
@ -687,7 +688,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath, params)
s.model, err = model.New(ctx, mpath, params)
if err != nil {
panic(err)
}
@ -794,13 +795,13 @@ func Execute(args []string) error {
}
server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)