Mirror of https://github.com/ollama/ollama.git, synced 2025-06-17 15:50:58 +02:00.
server/internal/client/ollama: confirm all chunksums were received (#9893)
If the chunksums response is missing a chunk, the client should fail the download. This changes the client to check that all bytes are accounted for in the chunksums response. Overlaps or gaps in the response could still cancel out in a total-size check, so the size alone is not a complete validation, but it provides enough coverage for now. We may want to check that chunks are contiguous later.
commit 2ddacd7516
parent da0e345200
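The accounting the commit adds is simple: the chunksums iterator keeps a running total of the chunk sizes, fails fast if that total ever exceeds the layer size, and reports a size mismatch at end of stream if the total falls short. A minimal, self-contained sketch of the same invariant (illustrative only — sumChunks and the inline digests are not from the repo; chunk ranges are inclusive, so "0-2" covers 3 bytes):

// Sketch only: mirrors the size-accounting idea from this commit,
// not the actual client code.
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func sumChunks(body string, layerSize int64) error {
	s := bufio.NewScanner(strings.NewReader(body))
	s.Split(bufio.ScanWords) // the response is whitespace-separated tokens
	var size int64
	for s.Scan() { // digest token (unused in this sketch)
		if !s.Scan() { // range token, e.g. "0-2"
			return fmt.Errorf("chunksums: digest without a range")
		}
		var start, end int64
		if _, err := fmt.Sscanf(s.Text(), "%d-%d", &start, &end); err != nil {
			return err
		}
		size += end - start + 1
		if size > layerSize { // fail fast: chunks already exceed the layer
			return fmt.Errorf("chunk size %d exceeds layer size %d", size, layerSize)
		}
	}
	if err := s.Err(); err != nil {
		return err
	}
	if size != layerSize { // end of stream: every byte must be accounted for
		return fmt.Errorf("size mismatch: sum of chunks %d != layer size %d", size, layerSize)
	}
	return nil
}

func main() {
	fmt.Println(sumChunks("abc 0-2\ndef 3-4\n", 5)) // <nil>: 3+2 bytes cover the layer
	fmt.Println(sumChunks("abc 0-1\n", 5))          // error: only 2 of 5 bytes accounted for
}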
diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
@@ -37,7 +37,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
-	"github.com/ollama/ollama/server/internal/internal/backoff"
 	"github.com/ollama/ollama/server/internal/internal/names"
 
 	_ "embed"
@@ -213,12 +212,6 @@ type Registry struct {
 	// request. If zero, [DefaultChunkingThreshold] is used.
 	ChunkingThreshold int64
 
-	// MaxChunkSize is the maximum size of a chunk to download. If zero,
-	// the default is [DefaultMaxChunkSize].
-	//
-	// It is only used when a layer is larger than [MaxChunkingThreshold].
-	MaxChunkSize int64
-
 	// Mask, if set, is the name used to convert non-fully qualified names
 	// to fully qualified names. If empty, [DefaultMask] is used.
 	Mask string
@@ -447,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err != nil {
 		return err
 	}
+
+	// TODO(bmizerany): decide if this should be considered valid. Maybe
+	// server-side we special case '{}' to have some special meaning? Maybe
+	// "archiving" a tag (which is how we reason about it in the registry
+	// already, just with a different twist).
 	if len(m.Layers) == 0 {
 		return fmt.Errorf("%w: no layers", ErrManifestInvalid)
 	}
@@ -456,11 +454,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	exists := func(l *Layer) bool {
-		info, err := c.Get(l.Digest)
-		return err == nil && info.Size == l.Size
-	}
-
+	// TODO(bmizerany): work to remove the need to do this
 	layers := m.Layers
 	if m.Config != nil && m.Config.Digest.IsValid() {
 		layers = append(layers, m.Config)
@@ -469,19 +463,16 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
 	t := traceFromContext(ctx)
-	skip := make([]bool, len(layers))
-	for i, l := range layers {
+	for _, l := range layers {
 		t.update(l, 0, nil)
-		if exists(l) {
-			skip[i] = true
-			t.update(l, l.Size, ErrCached)
-		}
 	}
 
-	g, ctx := errgroup.WithContext(ctx)
+	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
-	for i, l := range layers {
-		if skip[i] {
+	for _, l := range layers {
+		info, err := c.Get(l.Digest)
+		if err == nil && info.Size == l.Size {
+			t.update(l, l.Size, ErrCached)
 			continue
 		}
 
@@ -490,23 +481,26 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 			t.update(l, 0, err)
 			continue
 		}
+		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
 		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
+				// Bad chunksums response, update tracing
+				// clients and then bail.
 				t.update(l, progress.Load(), err)
-				break
+				return err
 			}
 
 			g.Go(func() (err error) {
-				defer func() { t.update(l, progress.Load(), err) }()
-
-				for _, err := range backoff.Loop(ctx, 3*time.Second) {
-					if err != nil {
-						return err
-					}
-					err := func() error {
-						req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-						if err != nil {
-							return err
-						}
+				defer func() {
+					if err != nil {
+						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
+					}
+					t.update(l, progress.Load(), err)
+				}()
+
+				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+				if err != nil {
+					return err
+				}
@@ -518,35 +512,19 @@
 				}
 				defer res.Body.Close()
 
-						// Count bytes towards
-						// progress, as they arrive, so
-						// that our bytes piggyback
-						// other chunk updates on
-						// completion.
-						//
-						// This tactic is enough to
-						// show "smooth" progress given
-						// the current CLI client. In
-						// the near future, the server
-						// should report download rate
-						// since it knows better than
-						// a client that is measuring
-						// rate based on wall-clock
-						// time-since-last-update.
-						body := &trackingReader{r: res.Body, n: &progress}
-
-						err = chunked.Put(cs.Chunk, cs.Digest, body)
-						if err != nil {
-							return err
-						}
-
-						return nil
-					}()
-					if !canRetry(err) {
-						return err
-					}
-				}
-				return nil
+				// Count bytes towards progress, as they
+				// arrive, so that our bytes piggyback other
+				// chunk updates on completion.
+				//
+				// This tactic is enough to show "smooth"
+				// progress given the current CLI client. In
+				// the near future, the server should report
+				// download rate since it knows better than a
+				// client that is measuring rate based on
+				// wall-clock time-since-last-update.
+				body := &trackingReader{r: res.Body, n: &progress}
+
+				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
 		}
 	}
@@ -554,13 +532,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	// store the manifest blob
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
 		return err
 	}
-
-	// commit the manifest with a link
 	return c.Link(m.Name, md)
 }
 
@@ -782,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 		}
 		blobURL := res.Header.Get("Content-Location")
 
+		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
					yield(chunksum{}, s.Err())
+				} else if size != l.Size {
+					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
 				}
 				return
 			}
@@ -811,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 				return
 			}
 
+			size += chunk.Size()
+			if size > l.Size {
+				yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
+				return
+			}
+
 			cs := chunksum{
 				URL:   blobURL,
 				Chunk: chunk,
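Note the shape of the rewritten download goroutine above: it relies on a named return value, so a single deferred function can wrap the error with the chunk digest and emit the final trace update on every exit path. A stripped-down illustration of that pattern (the download helper and errBoom here are hypothetical, not from the repo):

// Sketch only, not the repo's code.
package main

import (
	"errors"
	"fmt"
)

var errBoom = errors.New("boom")

// download uses a named return value, so the deferred function can
// observe and replace the error on every exit path before the caller
// (in the real code, an errgroup) ever sees it.
func download(short string) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("error downloading %s: %w", short, err)
		}
		// The real code also calls t.update(l, progress.Load(), err)
		// here, so the final trace update fires exactly once per chunk.
	}()
	return errBoom
}

func main() {
	err := download("abc123")
	fmt.Println(err)                     // error downloading abc123: boom
	fmt.Println(errors.Is(err, errBoom)) // true: %w keeps the chain intact
}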
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
@@ -17,6 +17,7 @@ import (
 	"reflect"
 	"slices"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -70,7 +71,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
 // communication is attempted.
 //
 // To simulate a network error, pass a handler that returns a 499 status code.
-func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
+func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	t.Helper()
 
 	c, err := blob.Open(t.TempDir())
@@ -88,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	r := &Registry{
 		Cache: c,
 		HTTPClient: &http.Client{
-			Transport: recordRoundTripper(h),
+			Transport: recordRoundTripper(upstreamRegistry),
 		},
 	}
 
@@ -767,3 +768,74 @@ func TestUnlink(t *testing.T) {
 		}
 	})
 }
+
+func TestPullChunksums(t *testing.T) {
+	check := testutil.Checker(t)
+
+	content := "hello"
+	var chunksums string
+	contentDigest := func() blob.Digest {
+		return blob.DigestFromBytes(content)
+	}
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/manifests/latest"):
+			fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
+		case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
+			loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
+			w.Header().Set("Content-Location", loc)
+			io.WriteString(w, chunksums)
+		case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
+			http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
+		default:
+			t.Errorf("unexpected request: %v", r)
+			http.NotFound(w, r)
+		}
+	})
+
+	rc.MaxStreams = 1        // prevent concurrent chunk downloads
+	rc.ChunkingThreshold = 1 // for all blobs to be chunked
+
+	var mu sync.Mutex
+	var reads []int64
+	ctx := WithTrace(t.Context(), &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			t.Logf("Update: %v %d %v", l, n, err)
+			mu.Lock()
+			reads = append(reads, n)
+			mu.Unlock()
+		},
+	})
+
+	chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
+		blob.DigestFromBytes("hel"),
+		blob.DigestFromBytes("lo"),
+	)
+	err := rc.Pull(ctx, "test")
+	check(err)
+	if !slices.Equal(reads, []int64{0, 3, 5}) {
+		t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
+	}
+
+	mw, err := rc.Resolve(t.Context(), "test")
+	check(err)
+	mg, err := rc.ResolveLocal("test")
+	check(err)
+	if !reflect.DeepEqual(mw, mg) {
+		t.Errorf("mw = %v; mg = %v", mw, mg)
+	}
+	for i := range mg.Layers {
+		_, err = c.Get(mg.Layers[i].Digest)
+		if err != nil {
+			t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
+		}
+	}
+
+	// missing chunks
+	content = "llama"
+	chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
+	err = rc.Pull(ctx, "missingchunks")
+	if err == nil {
+		t.Error("expected error because of missing chunks")
+	}
+}