mirror of
https://github.com/ollama/ollama.git
synced 2025-04-12 21:59:22 +02:00
cmd: compact pull progress and make client2 the default
This makes client2 the default for the pull command, and updates the progress output to be more compact and easier to read, omitting details relevant only to the original client implementation. Previously, the pull command would show a progress bar for each layer like this: ; ollama pull gemma3 pulling manifest pulling aeda25e63ebd... 100% ▕██████████████████████████ 3.3 GB pulling e0a42594d802... 100% ▕██████████████████████████ 358 B pulling dd084c7d92a3... 100% ▕██████████████████████████ 8.4 KB pulling 3116c5225075... 100% ▕██████████████████████████ 77 B pulling b6ae5839783f... 100% ▕██████████████████████████ 489 B verifying sha256 digest writing manifest success This changes the output to look like this: ; ollama pull gemma3 Downloading gemma3 123M/4.5G (2.7%) As a side note, the progress bar is not safe for concurrent use and so the pull command had race data races. This should fix any data races in the pull command, but it does not affect any other commands using the progress package. Also, move the call to Chunker.Close into a defer statement to ensure it it is called. Previously, care needed to be taken to ensure there we no panics, but without 100% control because we called t.update which is user-defined and could panic.
This commit is contained in:
parent
c001b98087
commit
b7f1a395ea
117
cmd/cmd.go
117
cmd/cmd.go
@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"maps"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
@ -773,57 +775,100 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := args[0] // cobra should be cofigured to always pass 1 arg
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
errc := make(chan error, 1)
|
||||
var mu sync.Mutex
|
||||
progressLocked := make(map[string][2]int64) // digest -> [completed, total]
|
||||
go func() {
|
||||
p := &api.PullRequest{Name: name}
|
||||
errc <- client.Pull(ctx, p, func(up api.ProgressResponse) error {
|
||||
if up.Digest == "" && up.Status != "" {
|
||||
// A status with no digest is a terminal
|
||||
// status. Give up and return an error.
|
||||
//
|
||||
// But first: Strip any "error: " prefix so it
|
||||
// does not stutter with Cobra's "Error: "
|
||||
// prefix.
|
||||
//
|
||||
// Our future client will handle the stream
|
||||
// updates and errors properly. This works for
|
||||
// now though.
|
||||
status := strings.TrimPrefix(up.Status, "error: ")
|
||||
return errors.New(status)
|
||||
}
|
||||
mu.Lock()
|
||||
progressLocked[up.Digest] = [2]int64{up.Completed, up.Total}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
t := time.NewTicker(100 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
fmt.Fprint(os.Stderr, escHideCursor)
|
||||
defer fmt.Fprint(os.Stderr, escShowCursor)
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
progress := make(map[string][2]int64)
|
||||
flushProgress := func() {
|
||||
mu.Lock()
|
||||
maps.Copy(progress, progressLocked)
|
||||
mu.Unlock()
|
||||
var completed, total int64
|
||||
for _, v := range progress {
|
||||
completed += v[0]
|
||||
total += v[1]
|
||||
}
|
||||
if total > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\rDownloading %s %s/%s (%0.1f%%)%s",
|
||||
name,
|
||||
formatNatural(completed),
|
||||
formatNatural(total),
|
||||
100*float64(completed)/float64(total),
|
||||
escClearToEOL,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
flushProgress()
|
||||
case err := <-errc:
|
||||
flushProgress()
|
||||
fmt.Fprintln(os.Stderr)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
const (
|
||||
escClearToEOL = "\x1b[K"
|
||||
escHideCursor = "\x1b[?25l"
|
||||
escShowCursor = "\x1b[?25h"
|
||||
)
|
||||
|
||||
// formatNatural formats a number of bytes into SI units. This aligns with
|
||||
// other downloaders like cURL.
|
||||
func formatNatural(n int64) string {
|
||||
switch {
|
||||
case n < 1024:
|
||||
return fmt.Sprintf("%d", n)
|
||||
case n < 1024*1024:
|
||||
return fmt.Sprintf("%.1fk", float64(n)/1024)
|
||||
case n < 1024*1024*1024:
|
||||
return fmt.Sprintf("%.1fM", float64(n)/(1024*1024))
|
||||
default:
|
||||
return fmt.Sprintf("%.1fG", float64(n)/(1024*1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
|
@ -107,15 +107,19 @@ func DefaultCache() (*blob.DiskCache, error) {
|
||||
//
|
||||
// In both cases, the code field is optional and may be empty.
|
||||
type Error struct {
|
||||
Status int `json:"-"` // TODO(bmizerany): remove this
|
||||
status int `json:"-"` // TODO(bmizerany): remove this
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.status > 500
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("registry responded with status ")
|
||||
b.WriteString(strconv.Itoa(e.Status))
|
||||
b.WriteString(strconv.Itoa(e.status))
|
||||
if e.Code != "" {
|
||||
b.WriteString(": code ")
|
||||
b.WriteString(e.Code)
|
||||
@ -129,7 +133,7 @@ func (e *Error) Error() string {
|
||||
|
||||
func (e *Error) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Int("status", e.Status),
|
||||
slog.Int("status", e.status),
|
||||
slog.String("code", e.Code),
|
||||
slog.String("message", e.Message),
|
||||
)
|
||||
@ -440,8 +444,8 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
// Pull pulls the model with the given name from the remote registry into the
|
||||
// cache.
|
||||
//
|
||||
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
|
||||
// chunks of the specified size, and then reassembled and verified. This is
|
||||
// For layers larger then [Registry.ChunkingThreshold], the layer is downloaded
|
||||
// in chunks of the specified size, and then reassembled and verified. This is
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
@ -489,94 +493,93 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||
if err != nil {
|
||||
t.update(l, 0, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
func() {
|
||||
var wg sync.WaitGroup
|
||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||
if err != nil {
|
||||
// Chunksum stream interrupted. Note in trace
|
||||
// log and let in-flight downloads complete.
|
||||
// This will naturally trigger ErrIncomplete
|
||||
// since received < expected bytes.
|
||||
t.update(l, 0, err)
|
||||
break
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// Close the chunked writer when all chunks are
|
||||
// downloaded.
|
||||
//
|
||||
// This is done as a background task in the
|
||||
// group to allow the next layer to start while
|
||||
// we wait for the final chunk in this layer to
|
||||
// complete. It also ensures this is done
|
||||
// before we exit Pull.
|
||||
g.Go(func() error {
|
||||
wg.Wait()
|
||||
chunked.Close()
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
cs.Digest,
|
||||
cs.Chunk.Start,
|
||||
cs.Chunk.End,
|
||||
)
|
||||
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
received.Add(cs.Chunk.Size())
|
||||
t.update(l, cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
}
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
// Chunksum stream interrupted. Note in trace
|
||||
// log and let in-flight downloads complete.
|
||||
// This will naturally trigger ErrIncomplete
|
||||
// since received < expected bytes.
|
||||
t.update(l, 0, err)
|
||||
break
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
// Ignore cache key write errors for now. We've already
|
||||
// reported to trace that the chunk is complete.
|
||||
//
|
||||
// Ideally, we should only report completion to trace
|
||||
// after successful cache commit. This current approach
|
||||
// works but could trigger unnecessary redownloads if
|
||||
// the checkpoint key is missing on next pull.
|
||||
//
|
||||
// Not incorrect, just suboptimal - fix this in a
|
||||
// future update.
|
||||
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
cs.Digest,
|
||||
cs.Chunk.Start,
|
||||
cs.Chunk.End,
|
||||
)
|
||||
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
received.Add(cs.Chunk.Size())
|
||||
t.update(l, cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
received.Add(cs.Chunk.Size())
|
||||
} else {
|
||||
t.update(l, 0, err)
|
||||
wg.Add(1)
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
// Ignore cache key write errors for now. We've already
|
||||
// reported to trace that the chunk is complete.
|
||||
//
|
||||
// Ideally, we should only report completion to trace
|
||||
// after successful cache commit. This current approach
|
||||
// works but could trigger unnecessary redownloads if
|
||||
// the checkpoint key is missing on next pull.
|
||||
//
|
||||
// Not incorrect, just suboptimal - fix this in a
|
||||
// future update.
|
||||
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
|
||||
received.Add(cs.Chunk.Size())
|
||||
} else {
|
||||
t.update(l, 0, err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body := &trackingReader{l: l, r: res.Body, update: t.update}
|
||||
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||
})
|
||||
}
|
||||
|
||||
// Close writer immediately after downloads finish, not at Pull
|
||||
// exit. Using defer would keep file descriptors open until all
|
||||
// layers complete, potentially exhausting system limits with
|
||||
// many layers.
|
||||
//
|
||||
// The WaitGroup tracks when all chunks finish downloading,
|
||||
// allowing precise writer closure in a background goroutine.
|
||||
// Each layer briefly uses one extra goroutine while at most
|
||||
// maxStreams()-1 chunks download in parallel.
|
||||
//
|
||||
// This caps file descriptors at maxStreams() instead of
|
||||
// growing with layer count.
|
||||
g.Go(func() error {
|
||||
wg.Wait()
|
||||
chunked.Close()
|
||||
return nil
|
||||
})
|
||||
body := &trackingReader{l: l, r: res.Body, update: t.update}
|
||||
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
@ -592,8 +595,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return c.Link(m.Name, md)
|
||||
}
|
||||
|
||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||
// before attempting to unlink the model.
|
||||
// Unlink applies Mask to the name and passes it to the cache's Unlink method.
|
||||
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
@ -973,7 +975,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
|
||||
return nil, ErrModelNotFound
|
||||
}
|
||||
|
||||
re.Status = res.StatusCode
|
||||
re.status = res.StatusCode
|
||||
return nil, &re
|
||||
}
|
||||
return res, nil
|
||||
|
@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func checkErrCode(t *testing.T, err error, status int, code string) {
|
||||
t.Helper()
|
||||
var e *Error
|
||||
if !errors.As(err, &e) || e.Status != status || e.Code != code {
|
||||
if !errors.As(err, &e) || e.status != status || e.Code != code {
|
||||
t.Errorf("err = %v; want %v %v", err, status, code)
|
||||
}
|
||||
}
|
||||
|
@ -37,7 +37,25 @@ type traceKey struct{}
|
||||
// WithTrace returns a context derived from ctx that uses t to report trace
|
||||
// events.
|
||||
func WithTrace(ctx context.Context, t *Trace) context.Context {
|
||||
return context.WithValue(ctx, traceKey{}, t)
|
||||
old := traceFromContext(ctx)
|
||||
if old == t {
|
||||
// No change, return the original context. This also prevents
|
||||
// infinite recursion below, if the caller passes the same
|
||||
// Trace.
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Create a new Trace that wraps the old one, if any. If we used the
|
||||
// same pointer t, we end up with a recursive structure.
|
||||
composed := &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if old != nil {
|
||||
old.update(l, n, err)
|
||||
}
|
||||
t.update(l, n, err)
|
||||
},
|
||||
}
|
||||
return context.WithValue(ctx, traceKey{}, composed)
|
||||
}
|
||||
|
||||
var emptyTrace = &Trace{}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||||
)
|
||||
|
||||
// Local implements an http.Handler for handling local Ollama API model
|
||||
@ -323,8 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- s.Client.Pull(ctx, p.model())
|
||||
go func() (err error) {
|
||||
defer func() { done <- err }()
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := s.Client.Pull(ctx, p.model())
|
||||
var oe *ollama.Error
|
||||
if errors.As(err, &oe) && oe.Temporary() {
|
||||
s.Logger.Error("downloading", "model", p.model(), "error", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
for {
|
||||
|
@ -184,7 +184,6 @@ func TestServerPull(t *testing.T) {
|
||||
|
||||
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
|
||||
t.Helper()
|
||||
|
||||
if got.Code != 200 {
|
||||
t.Errorf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
@ -203,12 +202,7 @@ func TestServerPull(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||
got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||
checkResponse(got, `
|
||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||
|
@ -42,12 +42,6 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
func experimentEnabled(name string) bool {
|
||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||
}
|
||||
|
||||
var useClient2 = experimentEnabled("client2")
|
||||
|
||||
var mode string = gin.DebugMode
|
||||
|
||||
type Server struct {
|
||||
@ -1274,12 +1268,9 @@ func Serve(ln net.Listener) error {
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
var err error
|
||||
rc, err = ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc, err = ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h, err := s.GenerateRoutes(rc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user