mirror of
https://github.com/ollama/ollama.git
synced 2025-04-12 21:59:22 +02:00
server/internal/client/ollama: cache completed chunks (#9933)
This change adds tracking of download chunks during the pull process so that subsequent pulls can skip downloading already completed chunks. This works across restarts of ollama. Currently, download state will be lost if a prune is triggered during a pull (e.g. restart or remove). This issue should be addressed in a follow-up PR.
This commit is contained in:
parent
b2a465296d
commit
ef27d52e79
@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
var re *Error
|
||||
if !errors.As(err, &re) {
|
||||
return false
|
||||
}
|
||||
return re.Status >= 500
|
||||
}
|
||||
|
||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
||||
// calls the update function with the layer, the number of bytes read.
|
||||
//
|
||||
@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
break
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
cs.Digest,
|
||||
cs.Chunk.Start,
|
||||
cs.Chunk.End,
|
||||
)
|
||||
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
received.Add(cs.Chunk.Size())
|
||||
t.update(l, cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
// Ignore cache key write errors for now. We've already
|
||||
// reported to trace that the chunk is complete.
|
||||
//
|
||||
// Ideally, we should only report completion to trace
|
||||
// after successful cache commit. This current approach
|
||||
// works but could trigger unnecessary redownloads if
|
||||
// the checkpoint key is missing on next pull.
|
||||
//
|
||||
// Not incorrect, just suboptimal - fix this in a
|
||||
// future update.
|
||||
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
|
||||
received.Add(cs.Chunk.Size())
|
||||
} else {
|
||||
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||
t.update(l, 0, err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return err
|
||||
}
|
||||
if received.Load() != expected {
|
||||
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
|
||||
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
|
||||
}
|
||||
|
||||
md := blob.DigestFromBytes(m.Data)
|
||||
@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manifest) All() iter.Seq[*Layer] {
|
||||
return func(yield func(*Layer) bool) {
|
||||
if !yield(m.Config) {
|
||||
return
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if !yield(l) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() int64 {
|
||||
var size int64
|
||||
if m.Config != nil {
|
||||
size += m.Config.Size
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
size += l.Size
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
//
|
||||
// NOTE: It adds an empty config object to the manifest, which is required by
|
||||
@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
||||
return
|
||||
}
|
||||
|
||||
// A chunksums response is a sequence of chunksums in a
|
||||
// simple, easy to parse line-oriented format.
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Example:
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// << HTTP/1.1 200 OK
|
||||
// << Content-Location: <blobURL>
|
||||
// <<
|
||||
// << <digest> <start>-<end>
|
||||
// << ...
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// The blobURL is the URL to download the chunks from.
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
|
@ -9,17 +9,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkNotExist(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "x")
|
||||
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullNotCached(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
var c *blob.DiskCache
|
||||
var rc *Registry
|
||||
|
||||
d := blob.DigestFromBytes("some data")
|
||||
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
io.WriteString(w, "some data")
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
|
||||
})
|
||||
|
||||
// Confirm that the layer does not exist locally
|
||||
_, err := rc.ResolveLocal("model")
|
||||
checkNotExist(t, err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
checkNotExist(t, err)
|
||||
|
||||
err = rc.Pull(t.Context(), "model")
|
||||
check(err)
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "model")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal("model")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
}
|
||||
|
||||
// Confirm successful download
|
||||
info, err := c.Get(d)
|
||||
check(err)
|
||||
if info.Digest != d {
|
||||
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
|
||||
}
|
||||
if info.Size != 9 {
|
||||
t.Errorf("info.Size = %v; want %v", info.Size, 9)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(c.GetFile(d))
|
||||
check(err)
|
||||
if string(data) != "some data" {
|
||||
t.Errorf("data = %q; want %q", data, "exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullCached(t *testing.T) {
|
||||
cached := blob.DigestFromBytes("exists")
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(499) // should not be called
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
|
||||
}
|
||||
})
|
||||
|
||||
var errs []error
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(d *Layer, n int64, err error) {
|
||||
t.Logf("update %v %d %v", d, n, err)
|
||||
reads = append(reads, n)
|
||||
errs = append(errs, err)
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 6}
|
||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
||||
}
|
||||
if !slices.Equal(reads, want) {
|
||||
t.Errorf("pairs = %v; want %v", reads, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "notfound")
|
||||
checkErrCode(t, err, 404, "")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
checkErrCode(t, err, 500, "an_error")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
return
|
||||
}
|
||||
})
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
|
||||
// interfere with pulling layers that are not cached
|
||||
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
x := blob.DigestFromBytes("xxxxxx")
|
||||
e := blob.DigestFromBytes("exists")
|
||||
y := blob.DigestFromBytes("yyyyyy")
|
||||
|
||||
for i := range 10 {
|
||||
t.Logf("iteration %d", i)
|
||||
|
||||
digests := []blob.Digest{x, e, y}
|
||||
|
||||
rand.Shuffle(len(digests), func(i, j int) {
|
||||
digests[i], digests[j] = digests[j], digests[i]
|
||||
})
|
||||
|
||||
manifest := fmt.Sprintf(`{
|
||||
"layers": [
|
||||
{"digest":"%s","size":6},
|
||||
{"digest":"%s","size":6},
|
||||
{"digest":"%s","size":6}
|
||||
]
|
||||
}`, digests[0], digests[1], digests[2])
|
||||
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch path.Base(r.URL.Path) {
|
||||
case "latest":
|
||||
io.WriteString(w, manifest)
|
||||
case x.String():
|
||||
io.WriteString(w, "xxxxxx")
|
||||
case e.String():
|
||||
io.WriteString(w, "exists")
|
||||
case y.String():
|
||||
io.WriteString(w, "yyyyyy")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected request: %v", r))
|
||||
}
|
||||
})
|
||||
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Logf("update %v %d %v", l, n, err)
|
||||
},
|
||||
})
|
||||
|
||||
// Check that we pull all layers that we can.
|
||||
|
||||
err := rc.Pull(ctx, "mixed")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, d := range digests {
|
||||
info, err := c.Get(d)
|
||||
if err != nil {
|
||||
t.Fatalf("Get(%v): %v", d, err)
|
||||
}
|
||||
if info.Size != 6 {
|
||||
t.Errorf("info.Size = %v; want %v", info.Size, 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestCanRetry(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{nil, false},
|
||||
{errors.New("x"), false},
|
||||
{ErrCached, false},
|
||||
{ErrManifestInvalid, false},
|
||||
{ErrNameInvalid, false},
|
||||
{&Error{Status: 100}, false},
|
||||
{&Error{Status: 500}, true},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
if got := canRetry(tt.err); got != tt.want {
|
||||
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorUnmarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
check := testutil.Checker(t)
|
||||
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
// make a blob and link it
|
||||
d := blob.DigestFromBytes("{}")
|
||||
err := blob.PutBytes(rc.Cache, d, "{}")
|
||||
check(err)
|
||||
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
||||
check(err)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal("single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
_, err = rc.ResolveLocal("single")
|
||||
check(err)
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink("single")
|
||||
testutil.Check(t, err)
|
||||
check(err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPullChunksums(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
// Many tests from here out, in this file are based on a single blob, "abc",
|
||||
// with the checksum of its sha256 hash. The checksum is:
|
||||
//
|
||||
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
||||
//
|
||||
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
||||
// to be the most readable and maintainable approach. The sum is consistently
|
||||
// used in the tests and unique so searches do not yield false positives.
|
||||
|
||||
content := "hello"
|
||||
var chunksums string
|
||||
contentDigest := func() blob.Digest {
|
||||
return blob.DigestFromBytes(content)
|
||||
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||
t.Helper()
|
||||
if got := req.URL.Path; got != path {
|
||||
t.Errorf("URL = %q, want %q", got, path)
|
||||
}
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
||||
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
||||
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
||||
w.Header().Set("Content-Location", loc)
|
||||
io.WriteString(w, chunksums)
|
||||
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
||||
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
||||
default:
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
if req.Method != method {
|
||||
t.Errorf("Method = %q, want %q", req.Method, method)
|
||||
}
|
||||
}
|
||||
|
||||
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
||||
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
||||
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(h)
|
||||
t.Cleanup(s.Close)
|
||||
cache, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Logf("Update: %v %d %v", l, n, err)
|
||||
mu.Lock()
|
||||
reads = append(reads, n)
|
||||
mu.Unlock()
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
},
|
||||
})
|
||||
|
||||
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
||||
blob.DigestFromBytes("hel"),
|
||||
blob.DigestFromBytes("lo"),
|
||||
)
|
||||
err := rc.Pull(ctx, "test")
|
||||
check(err)
|
||||
wantReads := []int64{
|
||||
0, // initial signaling of layer pull starting
|
||||
3, // first chunk read
|
||||
2, // second chunk read
|
||||
}
|
||||
if !slices.Equal(reads, wantReads) {
|
||||
t.Errorf("reads = %v; want %v", reads, wantReads)
|
||||
rc := &Registry{
|
||||
Cache: cache,
|
||||
HTTPClient: &http.Client{Transport: &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, s.Listener.Addr().String())
|
||||
},
|
||||
}},
|
||||
}
|
||||
return rc, ctx
|
||||
}
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "test")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal("test")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
}
|
||||
for i := range mg.Layers {
|
||||
_, err = c.Get(mg.Layers[i].Digest)
|
||||
if err != nil {
|
||||
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
||||
func TestPullChunked(t *testing.T) {
|
||||
var steps atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch steps.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
t.Logf("writing c")
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// missing chunks
|
||||
content = "llama"
|
||||
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
||||
err = rc.Pull(ctx, "missingchunks")
|
||||
if err == nil {
|
||||
t.Error("expected error because of missing chunks")
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
testutil.Check(t, err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
testutil.Check(t, err)
|
||||
|
||||
if g := steps.Load(); g != 4 {
|
||||
t.Fatalf("got %d steps, want 4", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullCached(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
// Premeptively cache the blob
|
||||
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
check(err)
|
||||
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
||||
check(err)
|
||||
|
||||
// Pull only the manifest, which should be enough to resolve the cached blob
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPullManifestError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var got *Error
|
||||
if !errors.Is(err, ErrModelNotFound) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `!`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var want *json.SyntaxError
|
||||
if !errors.As(err, &want) {
|
||||
t.Fatalf("err = %T, want %T", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerChecksumError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
var written atomic.Int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
var got *Error
|
||||
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
||||
t.Fatalf("err = %v, want %v", err, got)
|
||||
}
|
||||
|
||||
if g := written.Load(); g != 1 {
|
||||
t.Fatalf("wrote %d bytes, want 1", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullChunksumStreamError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
|
||||
// Write one valid chunksum and one invalid chunksum
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
||||
fmt.Fprint(w, "sha256:!") // invalid
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
got := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(got, ErrIncomplete) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
||||
}
|
||||
}
|
||||
|
||||
type flushAfterWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = f.w.Write(p)
|
||||
f.w.(http.Flusher).Flush() // panic if not a flusher
|
||||
return
|
||||
}
|
||||
|
||||
func TestPullChunksumStreaming(t *testing.T) {
|
||||
csr, csw := io.Pipe()
|
||||
defer csw.Close()
|
||||
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
||||
_, err := io.Copy(fw, csr)
|
||||
if err != nil {
|
||||
t.Errorf("copy: %v", err)
|
||||
}
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
update := make(chan int64, 1)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if n > 0 {
|
||||
update <- n
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
||||
}()
|
||||
|
||||
// Send first chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
if g := <-update; g != 2 {
|
||||
t.Fatalf("got %d, want 2", g)
|
||||
}
|
||||
|
||||
// now send the second chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
if g := <-update; g != 1 {
|
||||
t.Fatalf("got %d, want 1", g)
|
||||
}
|
||||
csw.Close()
|
||||
testutil.Check(t, <-errc)
|
||||
}
|
||||
|
||||
func TestPullChunksumsCached(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1 // force serial processing of chunksums
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the pull after the first chunksum is processed, but before
|
||||
// the second chunksum is processed (which is waiting because
|
||||
// MaxStreams=1). This should cause the second chunksum to error out
|
||||
// leaving the blob incomplete.
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Reset state and pull again to ensure the blob chunks that should
|
||||
// have been cached are, and the remaining chunk was downloaded, making
|
||||
// the blob complete.
|
||||
step.Store(0)
|
||||
var written atomic.Int64
|
||||
var cached atomic.Int64
|
||||
ctx = WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if errors.Is(err, ErrCached) {
|
||||
cached.Add(n)
|
||||
}
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
check(err)
|
||||
|
||||
if g := written.Load(); g != 3 {
|
||||
t.Fatalf("wrote %d bytes, want 3", g)
|
||||
}
|
||||
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
||||
t.Fatalf("cached %d bytes, want 3", g)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user