diff --git a/cmd/cmd.go b/cmd/cmd.go
index abb4806b5..041d1c5d6 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"maps"
 	"math"
 	"net"
 	"net/http"
@@ -21,6 +22,7 @@ import (
 	"sort"
 	"strconv"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"syscall"
 	"time"
@@ -773,57 +775,100 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
 }
 
 func PullHandler(cmd *cobra.Command, args []string) error {
-	insecure, err := cmd.Flags().GetBool("insecure")
-	if err != nil {
-		return err
-	}
+	name := args[0] // cobra should be configured to always pass 1 arg
 
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
 
-	p := progress.NewProgress(os.Stderr)
-	defer p.Stop()
+	ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt)
+	defer cancel()
 
-	bars := make(map[string]*progress.Bar)
-
-	var status string
-	var spinner *progress.Spinner
-
-	fn := func(resp api.ProgressResponse) error {
-		if resp.Digest != "" {
-			if spinner != nil {
-				spinner.Stop()
+	errc := make(chan error, 1)
+	var mu sync.Mutex
+	progressLocked := make(map[string][2]int64) // digest -> [completed, total]
+	go func() {
+		p := &api.PullRequest{Name: name}
+		errc <- client.Pull(ctx, p, func(up api.ProgressResponse) error {
+			if up.Digest == "" && up.Status != "" {
+				// A status with no digest is a terminal
+				// status. Give up and return an error.
+				//
+				// But first: Strip any "error: " prefix so it
+				// does not stutter with Cobra's "Error: "
+				// prefix.
+				//
+				// Our future client will handle the stream
+				// updates and errors properly. This works for
+				// now though.
+				status := strings.TrimPrefix(up.Status, "error: ")
+				return errors.New(status)
 			}
+			mu.Lock()
+			progressLocked[up.Digest] = [2]int64{up.Completed, up.Total}
+			mu.Unlock()
+			return nil
+		})
+	}()
 
-			bar, ok := bars[resp.Digest]
-			if !ok {
-				bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
-				bars[resp.Digest] = bar
-				p.Add(resp.Digest, bar)
-			}
+	t := time.NewTicker(100 * time.Millisecond)
+	defer t.Stop()
 
-			bar.Set(resp.Completed)
-		} else if status != resp.Status {
-			if spinner != nil {
-				spinner.Stop()
-			}
+	fmt.Fprint(os.Stderr, escHideCursor)
+	defer fmt.Fprint(os.Stderr, escShowCursor)
 
-			status = resp.Status
-			spinner = progress.NewSpinner(status)
-			p.Add(status, spinner)
+	progress := make(map[string][2]int64)
+	flushProgress := func() {
+		mu.Lock()
+		maps.Copy(progress, progressLocked)
+		mu.Unlock()
+
+		var completed, total int64
+		for _, v := range progress {
+			completed += v[0]
+			total += v[1]
+		}
+		if total > 0 {
+			fmt.Fprintf(os.Stderr, "\rDownloading %s %s/%s (%0.1f%%)%s",
+				name,
+				formatNatural(completed),
+				formatNatural(total),
+				100*float64(completed)/float64(total),
+				escClearToEOL,
+			)
 		}
-
-		return nil
 	}
-
-	request := api.PullRequest{Name: args[0], Insecure: insecure}
-	if err := client.Pull(cmd.Context(), &request, fn); err != nil {
-		return err
+	for {
+		select {
+		case <-t.C:
+			flushProgress()
+		case err := <-errc:
+			flushProgress()
+			fmt.Fprintln(os.Stderr)
+			return err
+		}
 	}
+}
 
-	return nil
+const (
+	escClearToEOL = "\x1b[K"
+	escHideCursor = "\x1b[?25l"
+	escShowCursor = "\x1b[?25h"
+)
+
+// formatNatural formats a number of bytes in human-readable form using
+// 1024-based units. This aligns with other downloaders like cURL.
+func formatNatural(n int64) string {
+	switch {
+	case n < 1024:
+		return fmt.Sprintf("%d", n)
+	case n < 1024*1024:
+		return fmt.Sprintf("%.1fk", float64(n)/1024)
+	case n < 1024*1024*1024:
+		return fmt.Sprintf("%.1fM", float64(n)/(1024*1024))
+	default:
+		return fmt.Sprintf("%.1fG", float64(n)/(1024*1024*1024))
+	}
 }
 
 type generateContextKey string
diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
index 409932bfd..3ba0f4fa0 100644
--- a/server/internal/client/ollama/registry.go
+++ b/server/internal/client/ollama/registry.go
@@ -107,15 +107,19 @@ func DefaultCache() (*blob.DiskCache, error) {
 //
 // In both cases, the code field is optional and may be empty.
 type Error struct {
-	Status  int    `json:"-"` // TODO(bmizerany): remove this
+	status  int    `json:"-"` // TODO(bmizerany): remove this
 	Code    string `json:"code"`
 	Message string `json:"message"`
 }
 
+func (e *Error) Temporary() bool {
+	return e.status > 500
+}
+
 func (e *Error) Error() string {
 	var b strings.Builder
 	b.WriteString("registry responded with status ")
-	b.WriteString(strconv.Itoa(e.Status))
+	b.WriteString(strconv.Itoa(e.status))
 	if e.Code != "" {
 		b.WriteString(": code ")
 		b.WriteString(e.Code)
@@ -129,7 +133,7 @@ func (e *Error) Error() string {
 
 func (e *Error) LogValue() slog.Value {
 	return slog.GroupValue(
-		slog.Int("status", e.Status),
+		slog.Int("status", e.status),
 		slog.String("code", e.Code),
 		slog.String("message", e.Message),
 	)
@@ -440,8 +444,8 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
 // Pull pulls the model with the given name from the remote registry into the
 // cache.
 //
-// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
-// chunks of the specified size, and then reassembled and verified. This is
+// For layers larger than [Registry.ChunkingThreshold], the layer is downloaded
+// in chunks of the specified size, and then reassembled and verified. This is
 // typically slower than splitting the model up across layers, and is mostly
 // utilized for layers of type equal to "application/vnd.ollama.image".
 func (r *Registry) Pull(ctx context.Context, name string) error {
@@ -489,94 +493,93 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 			continue
 		}
 
-		var wg sync.WaitGroup
-		chunked, err := c.Chunked(l.Digest, l.Size)
-		if err != nil {
-			t.update(l, 0, err)
-			continue
-		}
-
-		for cs, err := range r.chunksums(ctx, name, l) {
+		func() {
+			var wg sync.WaitGroup
+			chunked, err := c.Chunked(l.Digest, l.Size)
 			if err != nil {
-				// Chunksum stream interrupted. Note in trace
-				// log and let in-flight downloads complete.
-				// This will naturally trigger ErrIncomplete
-				// since received < expected bytes.
 				t.update(l, 0, err)
-				break
+				return
 			}
+			defer func() {
+				// Close the chunked writer when all chunks are
+				// downloaded.
+				//
+				// This is done as a background task in the
+				// group to allow the next layer to start while
+				// we wait for the final chunk in this layer to
+				// complete. It also ensures this is done
+				// before we exit Pull.
+				g.Go(func() error {
+					wg.Wait()
+					chunked.Close()
+					return nil
+				})
+			}()
 
-			cacheKey := fmt.Sprintf(
-				"v1 pull chunksum %s %s %d-%d",
-				l.Digest,
-				cs.Digest,
-				cs.Chunk.Start,
-				cs.Chunk.End,
-			)
-			cacheKeyDigest := blob.DigestFromBytes(cacheKey)
-			_, err := c.Get(cacheKeyDigest)
-			if err == nil {
-				received.Add(cs.Chunk.Size())
-				t.update(l, cs.Chunk.Size(), ErrCached)
-				continue
-			}
+			for cs, err := range r.chunksums(ctx, name, l) {
+				if err != nil {
+					// Chunksum stream interrupted. Note in trace
+					// log and let in-flight downloads complete.
+					// This will naturally trigger ErrIncomplete
+					// since received < expected bytes.
+					t.update(l, 0, err)
+					break
+				}
 
-			wg.Add(1)
-			g.Go(func() (err error) {
-				defer func() {
-					if err == nil {
-						// Ignore cache key write errors for now. We've already
-						// reported to trace that the chunk is complete.
-						//
-						// Ideally, we should only report completion to trace
-						// after successful cache commit. This current approach
-						// works but could trigger unnecessary redownloads if
-						// the checkpoint key is missing on next pull.
-						//
-						// Not incorrect, just suboptimal - fix this in a
-						// future update.
-						_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+				cacheKey := fmt.Sprintf(
+					"v1 pull chunksum %s %s %d-%d",
+					l.Digest,
+					cs.Digest,
+					cs.Chunk.Start,
+					cs.Chunk.End,
+				)
+				cacheKeyDigest := blob.DigestFromBytes(cacheKey)
+				_, err := c.Get(cacheKeyDigest)
+				if err == nil {
+					received.Add(cs.Chunk.Size())
+					t.update(l, cs.Chunk.Size(), ErrCached)
+					continue
+				}
 
-						received.Add(cs.Chunk.Size())
-					} else {
-						t.update(l, 0, err)
+				wg.Add(1)
+				g.Go(func() (err error) {
+					defer func() {
+						if err == nil {
+							// Ignore cache key write errors for now. We've already
+							// reported to trace that the chunk is complete.
+							//
+							// Ideally, we should only report completion to trace
+							// after successful cache commit. This current approach
+							// works but could trigger unnecessary redownloads if
+							// the checkpoint key is missing on next pull.
+							//
+							// Not incorrect, just suboptimal - fix this in a
+							// future update.
+							_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+
+							received.Add(cs.Chunk.Size())
+						} else {
+							t.update(l, 0, err)
+						}
+						wg.Done()
+					}()
+
+					req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+					if err != nil {
+						return err
 					}
-					wg.Done()
-				}()
+					req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
+					res, err := sendRequest(r.client(), req)
+					if err != nil {
+						return err
+					}
+					defer res.Body.Close()
 
-				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-				if err != nil {
-					return err
-				}
-				req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-				res, err := sendRequest(r.client(), req)
-				if err != nil {
-					return err
-				}
-				defer res.Body.Close()
-
-				body := &trackingReader{l: l, r: res.Body, update: t.update}
-				return chunked.Put(cs.Chunk, cs.Digest, body)
-			})
-		}
-
-		// Close writer immediately after downloads finish, not at Pull
-		// exit. Using defer would keep file descriptors open until all
-		// layers complete, potentially exhausting system limits with
-		// many layers.
-		//
-		// The WaitGroup tracks when all chunks finish downloading,
-		// allowing precise writer closure in a background goroutine.
-		// Each layer briefly uses one extra goroutine while at most
-		// maxStreams()-1 chunks download in parallel.
-		//
-		// This caps file descriptors at maxStreams() instead of
-		// growing with layer count.
-		g.Go(func() error {
-			wg.Wait()
-			chunked.Close()
-			return nil
-		})
+					body := &trackingReader{l: l, r: res.Body, update: t.update}
+					return chunked.Put(cs.Chunk, cs.Digest, body)
+				})
+			}
+		}()
 	}
 
 	if err := g.Wait(); err != nil {
 		return err
@@ -592,8 +595,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	return c.Link(m.Name, md)
 }
 
-// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
-// before attempting to unlink the model.
+// Unlink applies Mask to the name and passes it to the cache's Unlink method.
 func (r *Registry) Unlink(name string) (ok bool, _ error) {
 	n, err := r.parseName(name)
 	if err != nil {
@@ -973,7 +975,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
 			return nil, ErrModelNotFound
 		}
 
-		re.Status = res.StatusCode
+		re.status = res.StatusCode
 		return nil, &re
 	}
 	return res, nil
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
index 80d39b765..e7f5783c3 100644
--- a/server/internal/client/ollama/registry_test.go
+++ b/server/internal/client/ollama/registry_test.go
@@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
 func checkErrCode(t *testing.T, err error, status int, code string) {
 	t.Helper()
 	var e *Error
-	if !errors.As(err, &e) || e.Status != status || e.Code != code {
+	if !errors.As(err, &e) || e.status != status || e.Code != code {
 		t.Errorf("err = %v; want %v %v", err, status, code)
 	}
 }
diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go
index 69435c406..d0da332eb 100644
--- a/server/internal/client/ollama/trace.go
+++ b/server/internal/client/ollama/trace.go
@@ -37,7 +37,25 @@ type traceKey struct{}
 // WithTrace returns a context derived from ctx that uses t to report trace
 // events.
 func WithTrace(ctx context.Context, t *Trace) context.Context {
-	return context.WithValue(ctx, traceKey{}, t)
+	old := traceFromContext(ctx)
+	if old == t {
+		// No change, return the original context. This also prevents
+		// infinite recursion below, if the caller passes the same
+		// Trace.
+		return ctx
+	}
+
+	// Create a new Trace that wraps the old one, if any. If we used the
+	// same pointer t, we would end up with a recursive structure.
+	composed := &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if old != nil {
+				old.update(l, n, err)
+			}
+			t.update(l, n, err)
+		},
+	}
+	return context.WithValue(ctx, traceKey{}, composed)
 }
 
 var emptyTrace = &Trace{}
diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go
index 1910b1877..c366358a3 100644
--- a/server/internal/registry/server.go
+++ b/server/internal/registry/server.go
@@ -16,6 +16,7 @@ import (
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
+	"github.com/ollama/ollama/server/internal/internal/backoff"
 )
 
 // Local implements an http.Handler for handling local Ollama API model
@@ -323,8 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	})
 
 	done := make(chan error, 1)
-	go func() {
-		done <- s.Client.Pull(ctx, p.model())
+	go func() (err error) {
+		defer func() { done <- err }()
+		for _, err := range backoff.Loop(ctx, 3*time.Second) {
+			if err != nil {
+				return err
+			}
+			err := s.Client.Pull(ctx, p.model())
+			var oe *ollama.Error
+			if errors.As(err, &oe) && oe.Temporary() {
+				s.Logger.Error("downloading", "model", p.model(), "error", err)
+				continue
+			}
+			return err
+		}
+		return nil
 	}()
 
 	for {
diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go
index 3f20e518a..25bb9da42 100644
--- a/server/internal/registry/server_test.go
+++ b/server/internal/registry/server_test.go
@@ -184,7 +184,6 @@ func TestServerPull(t *testing.T) {
 
 	checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
 		t.Helper()
-
 		if got.Code != 200 {
 			t.Errorf("Code = %d; want 200", got.Code)
 		}
@@ -203,12 +202,7 @@ func TestServerPull(t *testing.T) {
 		}
 	}
 
-	got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
-	checkResponse(got, `
-		{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
-	`)
-
-	got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
+	got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
 	checkResponse(got, `
 		{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
 		{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
diff --git a/server/routes.go b/server/routes.go
index 92336af00..67dc50fbb 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -42,12 +42,6 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-func experimentEnabled(name string) bool {
-	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
-}
-
-var useClient2 = experimentEnabled("client2")
-
 var mode string = gin.DebugMode
 
 type Server struct {
@@ -1274,12 +1268,9 @@ func Serve(ln net.Listener) error {
 	s := &Server{addr: ln.Addr()}
 
 	var rc *ollama.Registry
-	if useClient2 {
-		var err error
-		rc, err = ollama.DefaultRegistry()
-		if err != nil {
-			return err
-		}
+	rc, err = ollama.DefaultRegistry()
+	if err != nil {
+		return err
 	}
 
 	h, err := s.GenerateRoutes(rc)
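Note on the new PullHandler: the server callback no longer draws anything. It only records per-digest [completed, total] pairs under a mutex, while a 100ms ticker owns the terminal and repaints one status line with "\r". A minimal, self-contained sketch of that pattern (the fake download loop below stands in for client.Pull; it is not the real client):

```go
package main

import (
	"fmt"
	"os"
	"sync"
	"time"
)

func main() {
	var mu sync.Mutex
	progress := map[string][2]int64{} // digest -> [completed, total]

	errc := make(chan error, 1)
	go func() {
		// Stand-in for client.Pull: the callback side only records state.
		for i := int64(0); i <= 100; i += 10 {
			mu.Lock()
			progress["sha256:demo"] = [2]int64{i, 100}
			mu.Unlock()
			time.Sleep(50 * time.Millisecond)
		}
		errc <- nil
	}()

	t := time.NewTicker(100 * time.Millisecond)
	defer t.Stop()

	// flush sums all counters and redraws a single line, so render
	// frequency is decoupled from callback frequency.
	flush := func() {
		mu.Lock()
		var completed, total int64
		for _, v := range progress {
			completed += v[0]
			total += v[1]
		}
		mu.Unlock()
		if total > 0 {
			fmt.Fprintf(os.Stderr, "\r%d/%d (%0.1f%%)", completed, total,
				100*float64(completed)/float64(total))
		}
	}

	for {
		select {
		case <-t.C:
			flush()
		case err := <-errc:
			flush() // final repaint so the last numbers are shown
			fmt.Fprintln(os.Stderr)
			if err != nil {
				fmt.Fprintln(os.Stderr, "error:", err)
			}
			return
		}
	}
}
```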
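The per-layer func() literal in Registry.Pull exists so the deferred g.Go(wg.Wait; chunked.Close) runs as soon as that layer's chunks are all enqueued, closing each writer when its last chunk lands rather than when Pull returns. A reduced sketch of the same shape, using golang.org/x/sync/errgroup in place of the package's internal group; writer names and chunk counts here are made up, and per-iteration loop-variable capture assumes Go 1.22+:

```go
package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	var g errgroup.Group
	g.SetLimit(4) // bounds in-flight chunk downloads, like maxStreams()

	layers := []int{3, 2, 4} // chunk counts per layer (made up)
	for li, nchunks := range layers {
		func() {
			var wg sync.WaitGroup
			writer := fmt.Sprintf("writer-%d", li) // stand-in for c.Chunked(...)
			defer func() {
				// One background task per layer closes the writer as
				// soon as its own chunks finish, without blocking the
				// next layer from starting.
				g.Go(func() error {
					wg.Wait()
					fmt.Println("closed", writer)
					return nil
				})
			}()
			for ci := 0; ci < nchunks; ci++ {
				wg.Add(1)
				g.Go(func() error {
					defer wg.Done()
					fmt.Printf("layer %d chunk %d\n", li, ci)
					return nil
				})
			}
		}()
	}
	if err := g.Wait(); err != nil {
		fmt.Println("err:", err)
	}
}
```

Because g.Go blocks the caller once the limit is reached, the defer only runs after every chunk of the layer has actually started, so the waiter can never deadlock against its own chunks; it just briefly occupies one of the group's slots, matching the "maxStreams()-1 chunks in parallel" note in the removed comment.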
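WithTrace now layers callbacks instead of replacing them: the composed Trace invokes the previously installed Update (if any) before the new one, so nested WithTrace calls all observe progress. The composition, reduced to plain function values (the real Trace and Layer types are elided here):

```go
package main

import "fmt"

// update mirrors the shape of Trace.Update: layer identity, byte count, error.
type update func(layer string, n int64, err error)

// compose returns a callback that invokes old (if any) before next, the same
// layering WithTrace applies to the Trace already stored in the context.
func compose(old, next update) update {
	return func(layer string, n int64, err error) {
		if old != nil {
			old(layer, n, err)
		}
		next(layer, n, err)
	}
}

func main() {
	logA := func(l string, n int64, _ error) { fmt.Println("A:", l, n) }
	logB := func(l string, n int64, _ error) { fmt.Println("B:", l, n) }
	u := compose(logA, logB)
	u("sha256:demo", 42, nil) // prints A first, then B
}
```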
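Finally, handlePull treats a registry Error whose Temporary() reports true (status > 500) as retryable, looping via the internal backoff package and logging each failed attempt. A sketch of that retry policy with a hypothetical doPull and a fixed delay standing in for backoff.Loop (which grows the interval between attempts):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type registryError struct{ status int }

func (e *registryError) Error() string   { return fmt.Sprintf("registry status %d", e.status) }
func (e *registryError) Temporary() bool { return e.status > 500 }

// pullWithRetry retries doPull on temporary registry errors, giving up on
// context cancellation or any permanent error. doPull is a hypothetical
// stand-in for s.Client.Pull.
func pullWithRetry(ctx context.Context, doPull func(context.Context) error) error {
	for {
		err := doPull(ctx)
		var re *registryError
		if errors.As(err, &re) && re.Temporary() {
			select {
			case <-time.After(100 * time.Millisecond): // fixed delay; backoff.Loop would grow this
				continue
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		return err // nil, or a permanent error
	}
}

func main() {
	attempts := 0
	err := pullWithRetry(context.Background(), func(context.Context) error {
		attempts++
		if attempts < 2 {
			return &registryError{status: 503} // temporary on first try
		}
		return nil
	})
	fmt.Println("attempts:", attempts, "err:", err)
}
```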