diff --git a/server/download.go b/server/download.go index 8b5b577fd..45483ba68 100644 --- a/server/download.go +++ b/server/download.go @@ -44,17 +44,19 @@ type blobDownload struct { context.CancelFunc - done bool + done chan struct{} err error references atomic.Int32 } type blobDownloadPart struct { - N int - Offset int64 - Size int64 - Completed int64 - lastUpdated time.Time + N int + Offset int64 + Size int64 + Completed atomic.Int64 + + lastUpdatedMu sync.Mutex + lastUpdated time.Time *blobDownload `json:"-"` } @@ -72,7 +74,7 @@ func (p *blobDownloadPart) Name() string { } func (p *blobDownloadPart) StartsAt() int64 { - return p.Offset + p.Completed + return p.Offset + p.Completed.Load() } func (p *blobDownloadPart) StopsAt() int64 { @@ -82,7 +84,9 @@ func (p *blobDownloadPart) StopsAt() int64 { func (p *blobDownloadPart) Write(b []byte) (n int, err error) { n = len(b) p.blobDownload.Completed.Add(int64(n)) + p.lastUpdatedMu.Lock() p.lastUpdated = time.Now() + p.lastUpdatedMu.Unlock() return n, nil } @@ -92,6 +96,8 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r return err } + b.done = make(chan struct{}) + for _, partFilePath := range partFilePaths { part, err := b.readPart(partFilePath) if err != nil { @@ -99,7 +105,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r } b.Total += part.Size - b.Completed.Add(part.Completed) + b.Completed.Add(part.Completed.Load()) b.Parts = append(b.Parts, part) } @@ -139,6 +145,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r } func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { + defer close(b.done) b.err = b.run(ctx, requestURL, opts) } @@ -230,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis g.SetLimit(numDownloadParts) for i := range b.Parts { part := b.Parts[i] - if part.Completed == part.Size { + if part.Completed.Load() == part.Size { continue } @@ -238,7 +245,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis var err error for try := 0; try < maxRetries; try++ { w := io.NewOffsetWriter(file, part.StartsAt()) - err = b.downloadChunk(inner, directURL, w, part, opts) + err = b.downloadChunk(inner, directURL, w, part) switch { case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): // return immediately if the context is canceled or the device is out of space @@ -279,29 +286,31 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return err } - b.done = true return nil } -func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error { +func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - headers := make(http.Header) - headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) - resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil) + if err != nil { + return err + } + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) + resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() - n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed) + n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load()) if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { // rollback progress b.Completed.Add(-n) return err } - part.Completed += n + part.Completed.Add(n) if err := b.writePart(part.Name(), part); err != nil { return err } @@ -315,15 +324,21 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w for { select { case <-ticker.C: - if part.Completed >= part.Size { + if part.Completed.Load() >= part.Size { return nil } - if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second { + part.lastUpdatedMu.Lock() + lastUpdated := part.lastUpdated + part.lastUpdatedMu.Unlock() + + if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second { const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection." slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N)) // reset last updated + part.lastUpdatedMu.Lock() part.lastUpdated = time.Time{} + part.lastUpdatedMu.Unlock() return errPartStalled } case <-ctx.Done(): @@ -388,6 +403,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) ticker := time.NewTicker(60 * time.Millisecond) for { select { + case <-b.done: + return b.err case <-ticker.C: fn(api.ProgressResponse{ Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), @@ -395,10 +412,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) Total: b.Total, Completed: b.Completed.Load(), }) - - if b.done || b.err != nil { - return b.err - } case <-ctx.Done(): return ctx.Err() }