cmd: compact pull progress and make client2 the default

This makes client2 the default for the pull command, and updates the
progress output to be more compact and easier to read, omitting details
relevant only to the original client implementation.

Previously, the pull command would show a progress bar for each layer
like this:

	; ollama pull gemma3
	pulling manifest
	pulling aeda25e63ebd... 100% ▕██████████████████████████ 3.3 GB
	pulling e0a42594d802... 100% ▕██████████████████████████  358 B
	pulling dd084c7d92a3... 100% ▕██████████████████████████ 8.4 KB
	pulling 3116c5225075... 100% ▕██████████████████████████   77 B
	pulling b6ae5839783f... 100% ▕██████████████████████████  489 B
	verifying sha256 digest
	writing manifest
	success

This changes the output to look like this:

	; ollama pull gemma3
	Downloading gemma3 123M/4.5G (2.7%)

As a side note, the progress bar is not safe for concurrent use, so the
pull command had data races. This should fix the data races in the pull
command, but it does not affect other commands that use the progress
package.

Also, move the call to Chunker.Close into a defer statement to ensure it
is called. Previously, care needed to be taken to ensure there were no
panics, but we did not have full control over that, because we call
t.update, which is user-defined and could panic.
Blake Mizerany 2025-03-31 17:05:22 -07:00
parent c001b98087
commit b7f1a395ea
7 changed files with 211 additions and 147 deletions

@ -11,6 +11,7 @@ import (
"fmt"
"io"
"log"
"maps"
"math"
"net"
"net/http"
@ -21,6 +22,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@ -773,57 +775,100 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
}
func PullHandler(cmd *cobra.Command, args []string) error {
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
}
name := args[0] // cobra should be configured to always pass 1 arg
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt)
defer cancel()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if spinner != nil {
spinner.Stop()
errc := make(chan error, 1)
var mu sync.Mutex
progressLocked := make(map[string][2]int64) // digest -> [completed, total]
go func() {
p := &api.PullRequest{Name: name}
errc <- client.Pull(ctx, p, func(up api.ProgressResponse) error {
if up.Digest == "" && up.Status != "" {
// A status with no digest is a terminal
// status. Give up and return an error.
//
// But first: Strip any "error: " prefix so it
// does not stutter with Cobra's "Error: "
// prefix.
//
// Our future client will handle the stream
// updates and errors properly. This works for
// now though.
status := strings.TrimPrefix(up.Status, "error: ")
return errors.New(status)
}
mu.Lock()
progressLocked[up.Digest] = [2]int64{up.Completed, up.Total}
mu.Unlock()
return nil
})
}()
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
fmt.Fprint(os.Stderr, escHideCursor)
defer fmt.Fprint(os.Stderr, escShowCursor)
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
progress := make(map[string][2]int64)
flushProgress := func() {
mu.Lock()
maps.Copy(progress, progressLocked)
mu.Unlock()
var completed, total int64
for _, v := range progress {
completed += v[0]
total += v[1]
}
if total > 0 {
fmt.Fprintf(os.Stderr, "\rDownloading %s %s/%s (%0.1f%%)%s",
name,
formatNatural(completed),
formatNatural(total),
100*float64(completed)/float64(total),
escClearToEOL,
)
}
return nil
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
for {
select {
case <-t.C:
flushProgress()
case err := <-errc:
flushProgress()
fmt.Fprintln(os.Stderr)
return err
}
}
}
return nil
const (
escClearToEOL = "\x1b[K"
escHideCursor = "\x1b[?25l"
escShowCursor = "\x1b[?25h"
)
// formatNatural formats a number of bytes into SI units. This aligns with
// other downloaders like cURL.
func formatNatural(n int64) string {
switch {
case n < 1024:
return fmt.Sprintf("%d", n)
case n < 1024*1024:
return fmt.Sprintf("%.1fk", float64(n)/1024)
case n < 1024*1024*1024:
return fmt.Sprintf("%.1fM", float64(n)/(1024*1024))
default:
return fmt.Sprintf("%.1fG", float64(n)/(1024*1024*1024))
}
}
type generateContextKey string
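
A quick table-driven check of the new formatNatural helper as written;
its thresholds are 1024-based even though the suffixes are single
letters. The test scaffolding and package name are assumed from the cmd
diff above, not part of this commit:

	package cmd

	import "testing"

	func TestFormatNatural(t *testing.T) {
		cases := []struct {
			in   int64
			want string
		}{
			{512, "512"},                     // below 1024: raw byte count
			{2048, "2.0k"},                   // 2 * 1024
			{3 * 1024 * 1024, "3.0M"},        // 3 MiB
			{5 * 1024 * 1024 * 1024, "5.0G"}, // 5 GiB
		}
		for _, c := range cases {
			if got := formatNatural(c.in); got != c.want {
				t.Errorf("formatNatural(%d) = %q; want %q", c.in, got, c.want)
			}
		}
	}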

@ -107,15 +107,19 @@ func DefaultCache() (*blob.DiskCache, error) {
//
// In both cases, the code field is optional and may be empty.
type Error struct {
Status int `json:"-"` // TODO(bmizerany): remove this
status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
func (e *Error) Temporary() bool {
return e.status > 500
}
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.Status))
b.WriteString(strconv.Itoa(e.status))
if e.Code != "" {
b.WriteString(": code ")
b.WriteString(e.Code)
@ -129,7 +133,7 @@ func (e *Error) Error() string {
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.Int("status", e.status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
@ -440,8 +444,8 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
// chunks of the specified size, and then reassembled and verified. This is
// For layers larger then [Registry.ChunkingThreshold], the layer is downloaded
// in chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error {
@ -489,94 +493,93 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
continue
}
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
for cs, err := range r.chunksums(ctx, name, l) {
func() {
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
return
}
defer func() {
// Close the chunked writer when all chunks are
// downloaded.
//
// This is done as a background task in the
// group to allow the next layer to start while
// we wait for the final chunk in this layer to
// complete. It also ensures this is done
// before we exit Pull.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}()
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
received.Add(cs.Chunk.Size())
} else {
t.update(l, 0, err)
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size())
} else {
t.update(l, 0, err)
}
wg.Done()
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
wg.Done()
}()
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
}()
}
if err := g.Wait(); err != nil {
return err
@ -592,8 +595,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return c.Link(m.Name, md)
}
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
// Unlink applies Mask to the name and passes it to the cache's Unlink method.
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
@ -973,7 +975,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
return nil, ErrModelNotFound
}
re.Status = res.StatusCode
re.status = res.StatusCode
return nil, &re
}
return res, nil

@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper()
var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code {
if !errors.As(err, &e) || e.status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code)
}
}

@ -37,7 +37,25 @@ type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace
// events.
func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t)
old := traceFromContext(ctx)
if old == t {
// No change, return the original context. This also prevents
// infinite recursion below, if the caller passes the same
// Trace.
return ctx
}
// Create a new Trace that wraps the old one, if any. If we used the
// same pointer t, we end up with a recursive structure.
composed := &Trace{
Update: func(l *Layer, n int64, err error) {
if old != nil {
old.update(l, n, err)
}
t.update(l, n, err)
},
}
return context.WithValue(ctx, traceKey{}, composed)
}
var emptyTrace = &Trace{}
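
The effect of the composed trace is that stacked WithTrace calls fan each
update out to every installed callback, earliest first. A rough usage
sketch (the import path follows the registry diff above; the package is
internal, so this only compiles from inside the ollama module):

	package main

	import (
		"context"
		"log"

		"github.com/ollama/ollama/server/internal/client/ollama"
	)

	func main() {
		ctx := context.Background()

		// First trace installed; its Update runs first for every event.
		ctx = ollama.WithTrace(ctx, &ollama.Trace{
			Update: func(l *ollama.Layer, n int64, err error) {
				log.Printf("first: %d bytes (err=%v)", n, err)
			},
		})

		// Second trace; the composed Trace built by WithTrace calls the
		// first callback, then this one.
		ctx = ollama.WithTrace(ctx, &ollama.Trace{
			Update: func(l *ollama.Layer, n int64, err error) {
				log.Printf("second: %d bytes (err=%v)", n, err)
			},
		})

		_ = ctx // pass ctx to Registry.Pull; both callbacks receive updates
	}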

@ -16,6 +16,7 @@ import (
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/internal/backoff"
)
// Local implements an http.Handler for handling local Ollama API model
@ -323,8 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
})
done := make(chan error, 1)
go func() {
done <- s.Client.Pull(ctx, p.model())
go func() (err error) {
defer func() { done <- err }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := s.Client.Pull(ctx, p.model())
var oe *ollama.Error
if errors.As(err, &oe) && oe.Temporary() {
s.Logger.Error("downloading", "model", p.model(), "error", err)
continue
}
return err
}
return nil
}()
for {

@ -184,7 +184,6 @@ func TestServerPull(t *testing.T) {
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
t.Helper()
if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code)
}
@ -203,12 +202,7 @@ func TestServerPull(t *testing.T) {
}
}
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
checkResponse(got, `
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}

@ -42,12 +42,6 @@ import (
"github.com/ollama/ollama/version"
)
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
var useClient2 = experimentEnabled("client2")
var mode string = gin.DebugMode
type Server struct {
@ -1274,12 +1268,9 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
var rc *ollama.Registry
if useClient2 {
var err error
rc, err = ollama.DefaultRegistry()
if err != nil {
return err
}
rc, err = ollama.DefaultRegistry()
if err != nil {
return err
}
h, err := s.GenerateRoutes(rc)