From 74bd09652d69c77a4bed34b3afda74c87295115b Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Mar 2025 13:03:16 -0700 Subject: [PATCH] ml/backend/ggml: load tensors in 32KiB chunks --- ml/backend.go | 9 ++--- ml/backend/ggml/ggml.go | 65 ++++++++++++++++++++++++----------- model/model.go | 5 +-- runner/ollamarunner/runner.go | 11 +++--- 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/ml/backend.go b/ml/backend.go index 66eb37f78..354faf432 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -2,6 +2,7 @@ package ml import ( "bytes" + "context" "encoding/binary" "fmt" "os" @@ -80,9 +81,9 @@ type BackendParams struct { FlashAttention bool } -var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) +var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error)) -func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) { +func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) { if _, ok := backends[name]; ok { panic("backend: backend already registered") } @@ -90,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro backends[name] = f } -func NewBackend(f *os.File, params BackendParams) (Backend, error) { +func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) { if backend, ok := backends["ggml"]; ok { - return backend(f, params) + return backend(ctx, f, params) } return nil, fmt.Errorf("unsupported backend") diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 6732470ef..f6b017748 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -9,15 +9,17 @@ package ggml import "C" import ( - "errors" + "context" "fmt" "io" "log/slog" "maps" "os" + "runtime" "slices" "strconv" "strings" + "sync/atomic" "unicode" "unsafe" @@ -58,7 +60,7 @@ type Backend struct { maxGraphNodes int } -func New(r *os.File, params ml.BackendParams) 
(ml.Backend, error) { +func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) { meta, n, err := fs.Decode(r, -1) if err != nil { return nil, err @@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { } } - // concurrently read in tensor data. uses a section reader which is safe for concurrent reads - sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) - var g errgroup.Group + var doneBytes atomic.Uint64 + totalBytes := uint64(n) - meta.Tensors().Offset + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.GOMAXPROCS(0)) for _, t := range meta.Tensors().Items() { - for _, target := range targets[t.Name] { - g.Go(func() error { + g.Go(func() error { + tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name]))) + for i := range tts { + target := targets[t.Name][i] if target == "" { target = t.Name } @@ -312,24 +318,43 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { return fmt.Errorf("unassigned tensor: %s", t.Name) } - bts := C.malloc(C.size_t(t.Size())) - if bts == nil { - return errors.New("failed to allocate tensor buffer") - } - defer C.free(bts) + tts[i] = tt + } - buf := unsafe.Slice((*byte)(bts), t.Size()) - n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf) - if err != nil || n != len(buf) { - return errors.New("read failed") + sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size())) + bts := make([]byte, 32*format.KibiByte) + + var s uint64 + for s < t.Size() { + n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))]) + if err != nil { + return err } - C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size())) - return nil - }) - } + for _, tt := range tts { + C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n)) + } + + s += uint64(n) + + if params.Progress != nil { + done := doneBytes.Add(uint64(n)) + 
params.Progress(float32(done) / float32(totalBytes)) + } + } + + return nil + }) } + // start a goroutine to cancel the errgroup if the parent context is done + go func() { + <-ctx.Done() + g.Go(func() error { + return ctx.Err() + }) + }() + if err := g.Wait(); err != nil { return nil, err } diff --git a/model/model.go b/model/model.go index ab29916ab..8355a55a8 100644 --- a/model/model.go +++ b/model/model.go @@ -1,6 +1,7 @@ package model import ( + "context" "errors" "fmt" _ "image/jpeg" @@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) { } // New initializes a new model instance with the provided configuration based on the metadata in the model file -func New(modelPath string, params ml.BackendParams) (Model, error) { +func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) { r, err := os.Open(modelPath) if err != nil { return nil, err } defer r.Close() - b, err := ml.NewBackend(r, params) + b, err := ml.NewBackend(ctx, r, params) if err != nil { return nil, err } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 67d9a1b02..31d20db80 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -678,6 +678,7 @@ func (m *multiLPath) String() string { } func (s *Server) loadModel( + ctx context.Context, mpath string, params ml.BackendParams, lpath multiLPath, @@ -687,7 +688,7 @@ func (s *Server) loadModel( multiUserCache bool, ) { var err error - s.model, err = model.New(mpath, params) + s.model, err = model.New(ctx, mpath, params) if err != nil { panic(err) } @@ -794,13 +795,13 @@ func Execute(args []string) error { } server.ready.Add(1) - go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) - - server.cond = sync.NewCond(&server.mu) - ctx, cancel := context.WithCancel(context.Background()) defer cancel() + go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) + + 
server.cond = sync.NewCond(&server.mu) + go server.run(ctx) addr := "127.0.0.1:" + strconv.Itoa(*port)