Merge pull request #9897 from ollama/mxyng/chunk-load

ml/backend/ggml: load tensors in 128KiB chunks
This commit is contained in:
Michael Yang 2025-03-21 14:47:13 -07:00 committed by GitHub
commit 4b34930a31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 31 deletions

View File

@@ -2,6 +2,7 @@ package ml
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os" "os"
@@ -80,9 +81,9 @@ type BackendParams struct {
FlashAttention bool FlashAttention bool
} }
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) { func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok { if _, ok := backends[name]; ok {
panic("backend: backend already registered") panic("backend: backend already registered")
} }
@@ -90,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
backends[name] = f backends[name] = f
} }
func NewBackend(f *os.File, params BackendParams) (Backend, error) { func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok { if backend, ok := backends["ggml"]; ok {
return backend(f, params) return backend(ctx, f, params)
} }
return nil, fmt.Errorf("unsupported backend") return nil, fmt.Errorf("unsupported backend")

View File

@@ -9,15 +9,17 @@ package ggml
import "C" import "C"
import ( import (
"errors" "context"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"maps" "maps"
"os" "os"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"unicode" "unicode"
"unsafe" "unsafe"
@@ -58,7 +60,7 @@ type Backend struct {
maxGraphNodes int maxGraphNodes int
} }
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1) meta, n, err := fs.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
} }
} }
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads var doneBytes atomic.Uint64
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) totalBytes := uint64(n) - meta.Tensors().Offset
var g errgroup.Group
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() { for _, t := range meta.Tensors().Items() {
for _, target := range targets[t.Name] { g.Go(func() error {
g.Go(func() error { tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" { if target == "" {
target = t.Name target = t.Name
} }
@@ -312,24 +318,43 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return fmt.Errorf("unassigned tensor: %s", t.Name) return fmt.Errorf("unassigned tensor: %s", t.Name)
} }
bts := C.malloc(C.size_t(t.Size())) tts[i] = tt
if bts == nil { }
return errors.New("failed to allocate tensor buffer")
}
defer C.free(bts)
buf := unsafe.Slice((*byte)(bts), t.Size()) sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf) bts := make([]byte, 128*format.KibiByte)
if err != nil || n != len(buf) {
return errors.New("read failed") var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
return err
} }
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size())) for _, tt := range tts {
return nil C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}) }
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
} }
// start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
} }
// New initializes a new model instance with the provided configuration based on the metadata in the model file // New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) { func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath) r, err := os.Open(modelPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
b, err := ml.NewBackend(r, params) b, err := ml.NewBackend(ctx, r, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -678,6 +678,7 @@ func (m *multiLPath) String() string {
} }
func (s *Server) loadModel( func (s *Server) loadModel(
ctx context.Context,
mpath string, mpath string,
params ml.BackendParams, params ml.BackendParams,
lpath multiLPath, lpath multiLPath,
@@ -687,7 +688,7 @@ func (s *Server) loadModel(
multiUserCache bool, multiUserCache bool,
) { ) {
var err error var err error
s.model, err = model.New(mpath, params) s.model, err = model.New(ctx, mpath, params)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -794,13 +795,13 @@ func Execute(args []string) error {
} }
server.ready.Add(1) server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
go server.run(ctx) go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port) addr := "127.0.0.1:" + strconv.Itoa(*port)