server/internal/client/ollama: confirm all chunksums were received (#9893)

If the chunksums response is missing a chunk, the client should fail
the download. This changes the client to check that all bytes are
accounted for in the chunksums response.

It is possible there are overlaps or gaps in the chunksums response, so
checking the total size alone is not sufficient, but it provides
enough coverage for now. We may want to verify that chunks are
contiguous and non-overlapping later.
This commit is contained in:
Blake Mizerany 2025-03-19 14:59:57 -07:00 committed by GitHub
parent da0e345200
commit 2ddacd7516
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 134 additions and 78 deletions

View File

@ -37,7 +37,6 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names" "github.com/ollama/ollama/server/internal/internal/names"
_ "embed" _ "embed"
@ -213,12 +212,6 @@ type Registry struct {
// request. If zero, [DefaultChunkingThreshold] is used. // request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64 ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names // Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used. // to fully qualified names. If empty, [DefaultMask] is used.
Mask string Mask string
@ -447,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
if err != nil { if err != nil {
return err return err
} }
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 { if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid) return fmt.Errorf("%w: no layers", ErrManifestInvalid)
} }
@ -456,11 +454,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err return err
} }
exists := func(l *Layer) bool { // TODO(bmizerany): work to remove the need to do this
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
layers := m.Layers layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() { if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config) layers = append(layers, m.Config)
@ -469,19 +463,16 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Send initial layer trace events to allow clients to have an // Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts. // understanding of work to be done before work starts.
t := traceFromContext(ctx) t := traceFromContext(ctx)
skip := make([]bool, len(layers)) for _, l := range layers {
for i, l := range layers {
t.update(l, 0, nil) t.update(l, 0, nil)
if exists(l) {
skip[i] = true
t.update(l, l.Size, ErrCached)
}
} }
g, ctx := errgroup.WithContext(ctx) var g errgroup.Group
g.SetLimit(r.maxStreams()) g.SetLimit(r.maxStreams())
for i, l := range layers { for _, l := range layers {
if skip[i] { info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
t.update(l, l.Size, ErrCached)
continue continue
} }
@ -490,23 +481,26 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
t.update(l, 0, err) t.update(l, 0, err)
continue continue
} }
// TODO(bmizerany): fix this unbounded use of defer
defer chunked.Close() defer chunked.Close()
var progress atomic.Int64 var progress atomic.Int64
for cs, err := range r.chunksums(ctx, name, l) { for cs, err := range r.chunksums(ctx, name, l) {
if err != nil { if err != nil {
// Bad chunksums response, update tracing
// clients and then bail.
t.update(l, progress.Load(), err) t.update(l, progress.Load(), err)
break return err
} }
g.Go(func() (err error) { g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }() defer func() {
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil { if err != nil {
return err err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
} }
err := func() error { t.update(l, progress.Load(), err)
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil) req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil { if err != nil {
return err return err
@ -518,35 +512,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
} }
defer res.Body.Close() defer res.Body.Close()
// Count bytes towards // Count bytes towards progress, as they
// progress, as they arrive, so // arrive, so that our bytes piggyback other
// that our bytes piggyback // chunk updates on completion.
// other chunk updates on
// completion.
// //
// This tactic is enough to // This tactic is enough to show "smooth"
// show "smooth" progress given // progress given the current CLI client. In
// the current CLI client. In // the near future, the server should report
// the near future, the server // download rate since it knows better than a
// should report download rate // client that is measuring rate based on
// since it knows better than // wall-clock time-since-last-update.
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress} body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body) return chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
}) })
} }
} }
@ -554,13 +532,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err return err
} }
// store the manifest blob
md := blob.DigestFromBytes(m.Data) md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil { if err := blob.PutBytes(c, md, m.Data); err != nil {
return err return err
} }
// commit the manifest with a link
return c.Link(m.Name, md) return c.Link(m.Name, md)
} }
@ -782,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
} }
blobURL := res.Header.Get("Content-Location") blobURL := res.Header.Get("Content-Location")
var size int64
s := bufio.NewScanner(res.Body) s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords) s.Split(bufio.ScanWords)
for { for {
if !s.Scan() { if !s.Scan() {
if s.Err() != nil { if s.Err() != nil {
yield(chunksum{}, s.Err()) yield(chunksum{}, s.Err())
} else if size != l.Size {
yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
} }
return return
} }
@ -811,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return return
} }
size += chunk.Size()
if size > l.Size {
yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
return
}
cs := chunksum{ cs := chunksum{
URL: blobURL, URL: blobURL,
Chunk: chunk, Chunk: chunk,

View File

@ -17,6 +17,7 @@ import (
"reflect" "reflect"
"slices" "slices"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -70,7 +71,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// communication is attempted. // communication is attempted.
// //
// To simulate a network error, pass a handler that returns a 499 status code. // To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper() t.Helper()
c, err := blob.Open(t.TempDir()) c, err := blob.Open(t.TempDir())
@ -88,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
r := &Registry{ r := &Registry{
Cache: c, Cache: c,
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Transport: recordRoundTripper(h), Transport: recordRoundTripper(upstreamRegistry),
}, },
} }
@ -767,3 +768,74 @@ func TestUnlink(t *testing.T) {
} }
}) })
} }
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
if !slices.Equal(reads, []int64{0, 3, 5}) {
t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}