server/internal/client/ollama: cache completed chunks (#9933)

This change adds tracking of download chunks during the pull process so
that subsequent pulls can skip downloading already completed chunks.
This works across restarts of ollama.

Currently, download state will be lost if a prune is triggered during a
pull (e.g. restart or remove). This issue should be addressed in a
follow-up PR.
This commit is contained in:
Blake Mizerany 2025-03-30 23:54:54 -07:00 committed by GitHub
parent b2a465296d
commit ef27d52e79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 443 additions and 303 deletions

View File

@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
return err
}
func canRetry(err error) bool {
var re *Error
if !errors.As(err, &re) {
return false
}
return re.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
break
}
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size())
} else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
t.update(l, 0, err)
}
wg.Done()
}()
@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
}
md := blob.DigestFromBytes(m.Data)
@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
return nil
}
func (m *Manifest) All() iter.Seq[*Layer] {
return func(yield func(*Layer) bool) {
if !yield(m.Config) {
return
}
for _, l := range m.Layers {
if !yield(l) {
return
}
}
}
}
func (m *Manifest) Size() int64 {
var size int64
if m.Config != nil {
size += m.Config.Size
}
for _, l := range m.Layers {
size += l.Size
}
return size
}
// MarshalJSON implements json.Marshaler.
//
// NOTE: It adds an empty config object to the manifest, which is required by
@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
// The response is a sequence of chunksums.
//
// Example:
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// The blobURL is the URL to download the chunks from.
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,

View File

@ -9,17 +9,14 @@ import (
"fmt"
"io"
"io/fs"
"math/rand/v2"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/testutil"
@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
}
}
func checkNotExist(t *testing.T, err error) {
t.Helper()
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v; want fs.ErrNotExist", err)
}
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, _ := newClient(t, nil)
rc, _ := newRegistryClient(t, nil)
err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}
for _, resp := range cases {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), "x")
err := rc.Pull(t.Context(), "http://example.com/a/b")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
}
}
func TestRegistryPullNotCached(t *testing.T) {
check := testutil.Checker(t)
var c *blob.DiskCache
var rc *Registry
d := blob.DigestFromBytes("some data")
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
io.WriteString(w, "some data")
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
})
// Confirm that the layer does not exist locally
_, err := rc.ResolveLocal("model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
// Confirm successful download
info, err := c.Get(d)
check(err)
if info.Digest != d {
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
}
if info.Size != 9 {
t.Errorf("info.Size = %v; want %v", info.Size, 9)
}
data, err := os.ReadFile(c.GetFile(d))
check(err)
if string(data) != "some data" {
t.Errorf("data = %q; want %q", data, "exists")
}
}
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
}
if strings.Contains(r.URL.Path, "/manifests/") {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
}
})
var errs []error
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
t.Logf("update %v %d %v", d, n, err)
reads = append(reads, n)
errs = append(errs, err)
},
})
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
if !slices.Equal(reads, want) {
t.Errorf("pairs = %v; want %v", reads, want)
}
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
// interfere with pulling layers that are not cached
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
x := blob.DigestFromBytes("xxxxxx")
e := blob.DigestFromBytes("exists")
y := blob.DigestFromBytes("yyyyyy")
for i := range 10 {
t.Logf("iteration %d", i)
digests := []blob.Digest{x, e, y}
rand.Shuffle(len(digests), func(i, j int) {
digests[i], digests[j] = digests[j], digests[i]
})
manifest := fmt.Sprintf(`{
"layers": [
{"digest":"%s","size":6},
{"digest":"%s","size":6},
{"digest":"%s","size":6}
]
}`, digests[0], digests[1], digests[2])
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "latest":
io.WriteString(w, manifest)
case x.String():
io.WriteString(w, "xxxxxx")
case e.String():
io.WriteString(w, "exists")
case y.String():
io.WriteString(w, "yyyyyy")
default:
panic(fmt.Sprintf("unexpected request: %v", r))
}
})
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("update %v %d %v", l, n, err)
},
})
// Check that we pull all layers that we can.
err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
for _, d := range digests {
info, err := c.Get(d)
if err != nil {
t.Fatalf("Get(%v): %v", d, err)
}
if info.Size != 6 {
t.Errorf("info.Size = %v; want %v", info.Size, 6)
}
}
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
testutil.Check(t, err)
}
func TestCanRetry(t *testing.T) {
cases := []struct {
err error
want bool
}{
{nil, false},
{errors.New("x"), false},
{ErrCached, false},
{ErrManifestInvalid, false},
{ErrNameInvalid, false},
{&Error{Status: 100}, false},
{&Error{Status: 500}, true},
}
for _, tt := range cases {
if got := canRetry(tt.err); got != tt.want {
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
}
}
}
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
check := testutil.Checker(t)
rc, _ := newRegistryClient(t, nil)
// make a blob and link it
d := blob.DigestFromBytes("{}")
err := blob.PutBytes(rc.Cache, d, "{}")
check(err)
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
check(err)
// confirm linked
_, err := rc.ResolveLocal("single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
_, err = rc.ResolveLocal("single")
check(err)
// unlink
_, err = rc.Unlink("single")
testutil.Check(t, err)
check(err)
// confirm unlinked
_, err = rc.ResolveLocal("single")
@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
}
})
t.Run("not found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
rc, _ := newRegistryClient(t, nil)
ok, err := rc.Unlink("manifestNotFound")
if err != nil {
t.Fatal(err)
@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
// Many tests from here out, in this file are based on a single blob, "abc",
// with the checksum of its sha256 hash. The checksum is:
//
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
//
// Using the literal value instead of a constant with fmt.Xprintf calls proved
// to be the most readable and maintainable approach. The sum is consistently
// used in the tests and unique so searches do not yield false positives.
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
func checkRequest(t *testing.T, req *http.Request, method, path string) {
t.Helper()
if got := req.URL.Path; got != path {
t.Errorf("URL = %q, want %q", got, path)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
if req.Method != method {
t.Errorf("Method = %q, want %q", req.Method, method)
}
}
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
s := httptest.NewServer(h)
t.Cleanup(s.Close)
cache, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
t.Log("trace:", l.Digest.Short(), n, err)
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
rc := &Registry{
Cache: cache,
HTTPClient: &http.Client{Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, s.Listener.Addr().String())
},
}},
}
return rc, ctx
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
func TestPullChunked(t *testing.T) {
var steps atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch steps.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
t.Logf("writing c")
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
default:
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
}
})
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
c.ChunkingThreshold = 1 // force chunking
err := c.Pull(ctx, "http://o.com/library/abc")
testutil.Check(t, err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
testutil.Check(t, err)
if g := steps.Load(); g != 4 {
t.Fatalf("got %d steps, want 4", g)
}
}
func TestPullCached(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
})
check := testutil.Checker(t)
// Premeptively cache the blob
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
check(err)
err = blob.PutBytes(c.Cache, d, []byte("abc"))
check(err)
// Pull only the manifest, which should be enough to resolve the cached blob
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
}
func TestPullManifestError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Fatalf("expected error")
}
var got *Error
if !errors.Is(err, ErrModelNotFound) {
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
}
}
func TestPullLayerError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `!`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Fatalf("expected error")
}
var want *json.SyntaxError
if !errors.As(err, &want) {
t.Fatalf("err = %T, want %T", err, want)
}
}
func TestPullLayerChecksumError(t *testing.T) {
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3:
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.MaxStreams = 1
c.ChunkingThreshold = 1 // force chunking
var written atomic.Int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
written.Add(n)
},
})
err := c.Pull(ctx, "http://o.com/library/abc")
var got *Error
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
t.Fatalf("err = %v, want %v", err, got)
}
if g := written.Load(); g != 1 {
t.Fatalf("wrote %d bytes, want 1", g)
}
}
func TestPullChunksumStreamError(t *testing.T) {
var step atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
// Write one valid chunksum and one invalid chunksum
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
fmt.Fprint(w, "sha256:!") // invalid
case 3:
io.WriteString(w, "ab")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
got := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(got, ErrIncomplete) {
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
}
}
type flushAfterWriter struct {
w io.Writer
}
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
n, err = f.w.Write(p)
f.w.(http.Flusher).Flush() // panic if not a flusher
return
}
func TestPullChunksumStreaming(t *testing.T) {
csr, csw := io.Pipe()
defer csw.Close()
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
_, err := io.Copy(fw, csr)
if err != nil {
t.Errorf("copy: %v", err)
}
case 3:
io.WriteString(w, "ab")
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
update := make(chan int64, 1)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if n > 0 {
update <- n
}
},
})
errc := make(chan error, 1)
go func() {
errc <- c.Pull(ctx, "http://o.com/library/abc")
}()
// Send first chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
if g := <-update; g != 2 {
t.Fatalf("got %d, want 2", g)
}
// now send the second chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
if g := <-update; g != 1 {
t.Fatalf("got %d, want 1", g)
}
csw.Close()
testutil.Check(t, <-errc)
}
func TestPullChunksumsCached(t *testing.T) {
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.MaxStreams = 1 // force serial processing of chunksums
c.ChunkingThreshold = 1 // force chunking
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Cancel the pull after the first chunksum is processed, but before
// the second chunksum is processed (which is waiting because
// MaxStreams=1). This should cause the second chunksum to error out
// leaving the blob incomplete.
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if n > 0 {
cancel()
}
},
})
err := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(err, context.Canceled) {
t.Fatalf("err = %v, want %v", err, context.Canceled)
}
_, err = c.Cache.Resolve("o.com/library/abc:latest")
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want nil", err)
}
// Reset state and pull again to ensure the blob chunks that should
// have been cached are, and the remaining chunk was downloaded, making
// the blob complete.
step.Store(0)
var written atomic.Int64
var cached atomic.Int64
ctx = WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if errors.Is(err, ErrCached) {
cached.Add(n)
}
written.Add(n)
},
})
check := testutil.Checker(t)
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
check(err)
if g := written.Load(); g != 3 {
t.Fatalf("wrote %d bytes, want 3", g)
}
if g := cached.Load(); g != 2 { // "ab" should have been cached
t.Fatalf("cached %d bytes, want 3", g)
}
}