From 10199c59879062d10056c285e1e1994286c929f8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 3 Oct 2023 16:52:49 -0700 Subject: [PATCH] replace done channel with file check --- server/download.go | 53 +++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/server/download.go b/server/download.go index e550a89a3..973e9eef9 100644 --- a/server/download.go +++ b/server/download.go @@ -31,10 +31,8 @@ type blobDownload struct { Total int64 Completed atomic.Int64 - *os.File Parts []*blobDownloadPart - done chan struct{} context.CancelFunc references atomic.Int32 } @@ -54,6 +52,14 @@ func (p *blobDownloadPart) Name() string { }, "-") } +func (p *blobDownloadPart) StartsAt() int64 { + return p.Offset + p.Completed +} + +func (p *blobDownloadPart) StopsAt() int64 { + return p.Offset + p.Size +} + func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { partFilePaths, err := filepath.Glob(b.Name + "-partial-*") if err != nil { @@ -110,18 +116,16 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis ctx, b.CancelFunc = context.WithCancel(ctx) - b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644) + file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644) if err != nil { return err } - defer b.Close() + defer file.Close() - b.Truncate(b.Total) - - b.done = make(chan struct{}, 1) - defer close(b.done) + file.Truncate(b.Total) g, ctx := errgroup.WithContext(ctx) + // TODO(mxyng): download concurrency should be configurable g.SetLimit(64) for i := range b.Parts { part := b.Parts[i] @@ -132,7 +136,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis i := i g.Go(func() error { for try := 0; try < maxRetries; try++ { - err := b.downloadChunk(ctx, requestURL, i, opts) + w := io.NewOffsetWriter(file, part.StartsAt()) + err := b.downloadChunk(ctx, requestURL, w, part, opts) switch { case errors.Is(err, context.Canceled): return err @@ -152,31 +157,23 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis return err } - if err := b.Close(); err != nil { + // explicitly close the file so we can rename it + if err := file.Close(); err != nil { return err } for i := range b.Parts { - if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil { + if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil { return err } } - if err := os.Rename(b.File.Name(), b.Name); err != nil { - return err - } - - return nil + return os.Rename(file.Name(), b.Name) } -func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error { - part := b.Parts[i] - - offset := part.Offset + part.Completed - w := io.NewOffsetWriter(b.File, offset) - +func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { headers := make(http.Header) - headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1)) + headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts) if err != nil { return err @@ -258,10 +255,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) ticker := time.NewTicker(60 * time.Millisecond) for { select { - case <-b.done: - if b.Completed.Load() != b.Total { - return io.ErrUnexpectedEOF - } case <-ticker.C: case <-ctx.Done(): return ctx.Err() @@ -275,8 +268,10 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) }) if b.Completed.Load() >= b.Total { - <-b.done - return nil + // wait for the file to get renamed + if _, err := os.Stat(b.Name); err == nil { + return nil + } } } }