diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 8af8812fb..56a2cc4fd 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -147,6 +147,7 @@ jobs:
runs-on: ${{ matrix.os }}
env:
CGO_ENABLED: '1'
+ GOEXPERIMENT: 'synctest'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
diff --git a/.gitignore b/.gitignore
index 551abec87..3a2af0bd1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,7 +5,6 @@
.swp
dist
build
-ollama
.cache
*.exe
.idea
@@ -14,3 +13,4 @@ test_data
__debug_bin*
llama/build
llama/vendor
+/ollama
diff --git a/.golangci.yaml b/.golangci.yaml
index 9d59fd6c0..9bb9786a8 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -6,8 +6,6 @@ linters:
- bidichk
- bodyclose
- containedctx
- - contextcheck
- - errcheck
- gocheckcompilerdirectives
- gofmt
- gofumpt
@@ -23,10 +21,11 @@ linters:
- staticcheck
- tenv
- unconvert
- - unused
- - usestdlibvars
- wastedassign
- whitespace
+ disable:
+ - usestdlibvars
+ - errcheck
linters-settings:
staticcheck:
checks:
@@ -39,5 +38,4 @@ severity:
- gofmt
- goimports
- intrange
- - usestdlibvars
severity: info
diff --git a/server/internal/cache/blob/cache.go b/server/internal/cache/blob/cache.go
new file mode 100644
index 000000000..f0b0760f1
--- /dev/null
+++ b/server/internal/cache/blob/cache.go
@@ -0,0 +1,544 @@
+// Package blob implements a content-addressable disk cache for blobs and
+// manifests.
+package blob
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "io/fs"
+ "iter"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/ollama/ollama/server/internal/internal/names"
+)
+
+// Entry contains metadata about a blob in the cache.
+type Entry struct {
+ Digest Digest
+ Size int64
+ Time time.Time // when added to the cache
+}
+
+// DiskCache caches blobs and manifests on disk.
+//
+// The cache is rooted at a directory, which is created if it does not exist.
+//
+// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the
+// "manifests" subdirectory. A example directory structure might look like:
+//
+//
/
+// blobs/
+// sha256- -
+// manifests/
+// /
+// /
+// /
+// -
+//
+// The cache is safe for concurrent use: it guards concurrent writes, but does
+// not prevent duplicated effort. Because blobs are immutable, duplicate writes
+// should result in the same file being written to disk.
+//
+// Name casing is preserved in the cache, but is not significant when resolving
+// names. For example, "Foo" and "foo" are considered the same name.
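+//
+// A minimal usage sketch (the cache directory is illustrative):
+//
+//	c, err := Open("/path/to/models")
+//	if err != nil {
+//		// handle error
+//	}
+//	d := DigestFromBytes("hello")
+//	err = PutBytes(c, d, "hello")
+//	e, err := c.Get(d) // e.Size == 5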
+type DiskCache struct {
+ // Dir specifies the top-level directory where blobs and manifest
+ // pointers are stored.
+ dir string
+ now func() time.Time
+
+ testHookBeforeFinalWrite func(f *os.File)
+}
+
+// PutBytes is a convenience function for c.Put(d, bytes.NewReader([]byte(data)), int64(len(data))).
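+//
+// For example, PutBytes(c, DigestFromBytes("hello"), "hello") stores the blob
+// "hello" under its SHA-256 digest.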
+func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
+ return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
+}
+
+// Open opens a cache rooted at the given directory. If the directory does not
+// exist, it is created. If the directory is not a directory, an error is
+// returned.
+func Open(dir string) (*DiskCache, error) {
+ if dir == "" {
+ return nil, errors.New("blob: empty directory name")
+ }
+
+ info, err := os.Stat(dir)
+ if err == nil && !info.IsDir() {
+ return nil, fmt.Errorf("%q is not a directory", dir)
+ }
+ if err := os.MkdirAll(dir, 0o777); err != nil {
+ return nil, err
+ }
+
+ subdirs := []string{"blobs", "manifests"}
+ for _, subdir := range subdirs {
+ if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil {
+ return nil, err
+ }
+ }
+
+ // TODO(bmizerany): support shards
+ c := &DiskCache{
+ dir: dir,
+ now: time.Now,
+ }
+ return c, nil
+}
+
+func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) {
+ f, err := os.Open(filename)
+ if err != nil {
+ return nil, Digest{}, err
+ }
+ defer f.Close()
+
+ h := sha256.New()
+ r := io.TeeReader(f, h)
+ data, err = io.ReadAll(io.LimitReader(r, limit))
+ if err != nil {
+ return nil, Digest{}, err
+ }
+ var d Digest
+ h.Sum(d.sum[:0])
+ return data, d, nil
+}
+
+//lint:ignore U1000 used for debugging purposes as needed in tests
+var debug = false
+
+// debugger returns a function that can be used to add a step to the error message.
+// The error message will be a list of steps that were taken before the error occurred.
+// The steps are added in the order they are called.
+//
+// To set the error message, call the returned function with an empty string.
+//
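+// A sketch of the intended pattern (the function below is illustrative):
+//
+//	func do() (err error) {
+//		dbg := debugger(&err)
+//		defer dbg("") // set the final error message, if any
+//		dbg("open")   // record a step
+//		// ...
+//		return err
+//	}
+//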
+//lint:ignore U1000 used for debugging purposes as needed in tests
+func debugger(err *error) func(step string) {
+ if !debug {
+ return func(string) {}
+ }
+ var steps []string
+ return func(step string) {
+ if step == "" && *err != nil {
+ *err = fmt.Errorf("%q: %w", steps, *err)
+ return
+ }
+ steps = append(steps, step)
+ if len(steps) > 100 {
+ // shift hints in case of a bug that causes a lot of hints
+ copy(steps, steps[1:])
+ steps = steps[:100]
+ }
+ }
+}
+
+// Resolve resolves a name to a digest. The name is expected to
+// be in either of the following forms:
+//
+//	@<digest>
+//	<name>@<digest>
+//	<name>
+//
+// If a digest is provided, it is returned as is and nothing else happens.
+//
+// If a name is provided for a manifest that exists in the cache, the digest
+// of the manifest is returned. If there is no manifest in the cache, it
+// returns [fs.ErrNotExist].
+//
+// To cover the case where a manifest may change without the cache knowing
+// (e.g. it was reformatted or modified by hand), the manifest data read and
+// hashed is passed to a PutBytes call to ensure that the manifest is in the
+// blob store. This is done to ensure that future calls to [Get] succeed in
+// these cases.
+//
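+// For example (the digest form is abbreviated):
+//
+//	d, err := c.Resolve("ollama.com/library/m:latest")            // resolves the linked manifest
+//	d, err := c.Resolve("ollama.com/library/m:latest@sha256:...") // returns the digest as is
+//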
+// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
+func (c *DiskCache) Resolve(name string) (Digest, error) {
+ name, digest := splitNameDigest(name)
+ if digest != "" {
+ return ParseDigest(digest)
+ }
+
+ // We want to address manifests files by digest using Get. That requires
+ // them to be blobs. This cannot be directly accomplished by looking in
+ // the blob store because manifests can change without Ollama knowing
+ // (e.g. a user modifies a manifests by hand then pushes it to update
+ // their model). We also need to support the blob caches inherited from
+ // older versions of Ollama, which do not store manifests in the blob
+ // store, so for these cases, we need to handle adding the manifests to
+ // the blob store, just in time.
+ //
+ // So now we read the manifests file, hash it, and copy it to the blob
+ // store if it's not already there.
+ //
+ // This should be cheap because manifests are small, and accessed
+ // infrequently.
+ file, err := c.manifestPath(name)
+ if err != nil {
+ return Digest{}, err
+ }
+
+ data, d, err := readAndSum(file, 1<<20)
+ if err != nil {
+ return Digest{}, err
+ }
+
+ // Ideally we'd read the "manifest" file as a manifest to the blob file,
+ // but we are not changing this yet, so copy the manifest to the blob
+// store so it can be addressed by digest in subsequent calls to Get.
+ if err := PutBytes(c, d, data); err != nil {
+ return Digest{}, err
+ }
+ return d, nil
+}
+
+// Put writes a new blob to the cache, identified by its digest. The operation
+// reads content from r, which must precisely match both the specified size and
+// digest.
+//
+// Concurrent write safety is achieved through file locking. The implementation
+// guarantees write integrity by enforcing size limits and content validation
+// before allowing the file to reach its final state.
+func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error {
+ return c.copyNamedFile(c.GetFile(d), r, d, size)
+}
+
+// Import imports a blob from the provided reader into the cache. It reads the
+// entire content of the reader, calculates its digest, and stores it in the
+// cache.
+//
+// Import should be considered unsafe for use with untrusted data, such as data
+// read from a network. The caller is responsible for ensuring the integrity of
+// the data being imported.
+func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) {
+ // Users who want to change the temp dir can set TMPDIR.
+ f, err := os.CreateTemp("", "blob-")
+ if err != nil {
+ return Digest{}, err
+ }
+ defer os.Remove(f.Name())
+
+ // Copy the blob to a temporary file.
+ h := sha256.New()
+ r = io.TeeReader(r, h)
+ n, err := io.Copy(f, r)
+ if err != nil {
+ return Digest{}, err
+ }
+ if n != size {
+ return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n)
+ }
+
+ // Check the digest.
+ var d Digest
+ h.Sum(d.sum[:0])
+ if err := f.Close(); err != nil {
+ return Digest{}, err
+ }
+ name := c.GetFile(d)
+ // Rename the temporary file to the final file.
+ if err := os.Rename(f.Name(), name); err != nil {
+ return Digest{}, err
+ }
+ os.Chtimes(name, c.now(), c.now()) // mainly for tests
+ return d, nil
+}
+
+// Get retrieves a blob from the cache using the provided digest. The operation
+// fails if the digest is malformed or if any errors occur during blob
+// retrieval.
+func (c *DiskCache) Get(d Digest) (Entry, error) {
+ name := c.GetFile(d)
+ info, err := os.Stat(name)
+ if err != nil {
+ return Entry{}, err
+ }
+ if info.Size() == 0 {
+ return Entry{}, fs.ErrNotExist
+ }
+ return Entry{
+ Digest: d,
+ Size: info.Size(),
+ Time: info.ModTime(),
+ }, nil
+}
+
+// Link creates a symbolic reference in the cache that maps the provided name
+// to a blob identified by its digest, making it retrievable by name using
+// [Resolve].
+//
+// It returns an error if either the name or digest is invalid, or if link
+// creation encounters any issues.
+func (c *DiskCache) Link(name string, d Digest) error {
+ manifest, err := c.manifestPath(name)
+ if err != nil {
+ return err
+ }
+ f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ // TODO(bmizerany): test this happens only if the blob was found to
+ // avoid leaving debris
+ if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil {
+ return err
+ }
+
+ info, err := f.Stat()
+ if err != nil {
+ return err
+ }
+
+ // Copy manifest to cache directory.
+ return c.copyNamedFile(manifest, f, d, info.Size())
+}
+
+// Unlink removes any link for name. If the link does not exist, nothing
+// happens, and no error is returned.
+//
+// It returns an error if the name is invalid or if the link removal encounters
+// any issues.
+func (c *DiskCache) Unlink(name string) error {
+ manifest, err := c.manifestPath(name)
+ if err != nil {
+ return err
+ }
+ err = os.Remove(manifest)
+ if errors.Is(err, fs.ErrNotExist) {
+ return nil
+ }
+ return err
+}
+
+// GetFile returns the absolute path to the file, in the cache, for the given
+// digest. It does not check if the file exists.
+//
+// The returned path should not be stored, used outside the lifetime of the
+// cache, or interpreted in any way.
+func (c *DiskCache) GetFile(d Digest) string {
+ filename := fmt.Sprintf("sha256-%x", d.sum)
+ return absJoin(c.dir, "blobs", filename)
+}
+
+// Links returns a sequence of links in the cache in lexical order.
+func (c *DiskCache) Links() iter.Seq2[string, error] {
+ return func(yield func(string, error) bool) {
+ for path, err := range c.links() {
+ if err != nil {
+ yield("", err)
+ return
+ }
+ if !yield(pathToName(path), nil) {
+ return
+ }
+ }
+ }
+}
+
+// pathToName converts a path to a name. It is the inverse of nameToPath. The
+// path is assumed to be in filepath.ToSlash format.
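+//
+// For example, pathToName("manifests/h/n/m/t") returns "h/n/m:t".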
+func pathToName(s string) string {
+ s = strings.TrimPrefix(s, "manifests/")
+ rr := []rune(s)
+ for i := len(rr) - 1; i > 0; i-- {
+ if rr[i] == '/' {
+ rr[i] = ':'
+ return string(rr)
+ }
+ }
+ return s
+}
+
+// manifestPath finds the first manifest file on disk that matches the given
+// name using a case-insensitive comparison. If no manifest file is found, it
+// returns the path where the manifest file would be if it existed.
+//
+// If two manifest files exist on disk that match the given name using a
+// case-insensitive comparison, the one that sorts first, lexically, is
+// returned.
+func (c *DiskCache) manifestPath(name string) (string, error) {
+ np, err := nameToPath(name)
+ if err != nil {
+ return "", err
+ }
+
+ maybe := filepath.Join("manifests", np)
+ for l, err := range c.links() {
+ if err != nil {
+ return "", err
+ }
+ if strings.EqualFold(maybe, l) {
+ return filepath.Join(c.dir, l), nil
+ }
+ }
+ return filepath.Join(c.dir, maybe), nil
+}
+
+// links returns a sequence of links in the cache in lexical order.
+func (c *DiskCache) links() iter.Seq2[string, error] {
+ // TODO(bmizerany): reuse empty dirnames if exist
+ return func(yield func(string, error) bool) {
+ fsys := os.DirFS(c.dir)
+ manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
+ if err != nil {
+ yield("", err)
+ return
+ }
+ for _, manifest := range manifests {
+ if !yield(manifest, nil) {
+ return
+ }
+ }
+ }
+}
+
+type checkWriter struct {
+ d Digest
+ size int64
+ n int64
+ h hash.Hash
+ f *os.File
+ err error
+
+ testHookBeforeFinalWrite func(*os.File)
+}
+
+func (w *checkWriter) seterr(err error) error {
+ if w.err == nil {
+ w.err = err
+ }
+ return err
+}
+
+// Write writes p to the underlying hash and writer. The last write to the
+// underlying writer is guaranteed to be the last byte of p as verified by the
+// hash.
+func (w *checkWriter) Write(p []byte) (int, error) {
+ _, err := w.h.Write(p)
+ if err != nil {
+ return 0, w.seterr(err)
+ }
+ nextSize := w.n + int64(len(p))
+ if nextSize == w.size {
+ // last write. check hash.
+ sum := w.h.Sum(nil)
+ if !bytes.Equal(sum, w.d.sum[:]) {
+ return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
+ }
+ if w.testHookBeforeFinalWrite != nil {
+ w.testHookBeforeFinalWrite(w.f)
+ }
+ }
+ if nextSize > w.size {
+ return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
+ }
+ n, err := w.f.Write(p)
+ w.n += int64(n)
+ return n, w.seterr(err)
+}
+
+// copyNamedFile copies file into name, expecting it to have the given Digest
+// and size, if that file is not present already.
+func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error {
+ info, err := os.Stat(name)
+ if err == nil && info.Size() == size {
+ // File already exists with correct size. This is good enough.
+ // We can skip expensive hash checks.
+ //
+ // TODO: Do the hash check, but give caller a way to skip it.
+ return nil
+ }
+
+ // Copy file to cache directory.
+ mode := os.O_RDWR | os.O_CREATE
+ if err == nil && info.Size() > size { // shouldn't happen but fix in case
+ mode |= os.O_TRUNC
+ }
+ f, err := os.OpenFile(name, mode, 0o666)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ if size == 0 {
+ // File now exists with correct size.
+ // Only one possible zero-length file, so contents are OK too.
+ // Early return here makes sure there's a "last byte" for code below.
+ return nil
+ }
+
+ // From here on, if any of the I/O writing the file fails,
+ // we make a best-effort attempt to truncate the file f
+ // before returning, to avoid leaving bad bytes in the file.
+
+ // Copy file to f, but also into h to double-check hash.
+ cw := &checkWriter{
+ d: out,
+ size: size,
+ h: sha256.New(),
+ f: f,
+ testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
+ }
+ n, err := io.Copy(cw, file)
+ if err != nil {
+ f.Truncate(0)
+ return err
+ }
+ if n < size {
+ f.Truncate(0)
+ return io.ErrUnexpectedEOF
+ }
+
+ if err := f.Close(); err != nil {
+ // Data might not have been written,
+ // but file may look like it is the right size.
+ // To be extra careful, remove cached file.
+ os.Remove(name)
+ return err
+ }
+ os.Chtimes(name, c.now(), c.now()) // mainly for tests
+
+ return nil
+}
+
+func splitNameDigest(s string) (name, digest string) {
+ i := strings.LastIndexByte(s, '@')
+ if i < 0 {
+ return s, ""
+ }
+ return s[:i], s[i+1:]
+}
+
+var errInvalidName = errors.New("invalid name")
+
+func nameToPath(name string) (_ string, err error) {
+ if strings.Contains(name, "@") {
+ // TODO(bmizerany): HACK: Fix names.Parse to validate.
+ // TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
+ return "", errInvalidName
+ }
+ n := names.Parse(name)
+ if !n.IsFullyQualified() {
+ return "", errInvalidName
+ }
+ return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil
+}
+
+func absJoin(pp ...string) string {
+ abs, err := filepath.Abs(filepath.Join(pp...))
+ if err != nil {
+ // Likely a bug or a bad OS problem. Just panic.
+ panic(err)
+ }
+ return abs
+}
diff --git a/server/internal/cache/blob/cache_test.go b/server/internal/cache/blob/cache_test.go
new file mode 100644
index 000000000..704542ea3
--- /dev/null
+++ b/server/internal/cache/blob/cache_test.go
@@ -0,0 +1,685 @@
+package blob
+
+import (
+ "crypto/sha256"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/ollama/ollama/server/internal/internal/testutil"
+)
+
+func init() {
+ debug = true
+}
+
+var epoch = func() time.Time {
+ d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
+ if d.IsZero() {
+ panic("time zero")
+ }
+ return d
+}()
+
+func TestOpenErrors(t *testing.T) {
+ exe, err := os.Executable()
+ if err != nil {
+ panic(err)
+ }
+
+ cases := []struct {
+ dir string
+ err string
+ }{
+ {t.TempDir(), ""},
+ {"", "empty directory name"},
+ {exe, "not a directory"},
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.dir, func(t *testing.T) {
+ _, err := Open(tt.dir)
+ if tt.err == "" {
+ if err != nil {
+ t.Fatal(err)
+ }
+ return
+ }
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), tt.err) {
+ t.Fatalf("err = %v, want %q", err, tt.err)
+ }
+ })
+ }
+}
+
+func TestGetFile(t *testing.T) {
+ t.Chdir(t.TempDir())
+
+ c, err := Open(".")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ d := mkdigest("1")
+ got := c.GetFile(d)
+ cleaned := filepath.Clean(got)
+ if cleaned != got {
+ t.Fatalf("got is unclean: %q", got)
+ }
+ if !filepath.IsAbs(got) {
+ t.Fatal("got is not absolute")
+ }
+ abs, _ := filepath.Abs(c.dir)
+ if !strings.HasPrefix(got, abs) {
+ t.Fatalf("got is not local to %q", c.dir)
+ }
+}
+
+func TestBasic(t *testing.T) {
+ c, err := Open(t.TempDir())
+ if err != nil {
+ t.Fatal(err)
+ }
+ now := epoch
+ c.now = func() time.Time { return now }
+
+ checkEntry := entryChecker(t, c)
+ checkFailed := func(err error) {
+ if err == nil {
+ t.Helper()
+ t.Fatal("expected error")
+ }
+ }
+
+ _, err = c.Resolve("invalid")
+ checkFailed(err)
+
+ _, err = c.Resolve("h/n/m:t")
+ checkFailed(err)
+
+ dx := mkdigest("x")
+
+ d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if d != dx {
+ t.Fatalf("d = %v, want %v", d, dx)
+ }
+
+ _, err = c.Get(Digest{})
+ checkFailed(err)
+
+ // not committed yet
+ _, err = c.Get(dx)
+ checkFailed(err)
+
+ err = PutBytes(c, dx, "!")
+ checkFailed(err)
+
+ err = PutBytes(c, dx, "x")
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkEntry(dx, 1, now)
+
+ t0 := now
+ now = now.Add(1*time.Hour + 1*time.Minute)
+ err = PutBytes(c, dx, "x")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check not updated
+ checkEntry(dx, 1, t0)
+}
+
+type sleepFunc func(d time.Duration) time.Time
+
+func openTester(t *testing.T) (*DiskCache, sleepFunc) {
+ t.Helper()
+ c, err := Open(t.TempDir())
+ if err != nil {
+ t.Fatal(err)
+ }
+ now := epoch
+ c.now = func() time.Time { return now }
+ return c, func(d time.Duration) time.Time {
+ now = now.Add(d)
+ return now
+ }
+}
+
+func TestManifestPath(t *testing.T) {
+ check := testutil.Checker(t)
+
+ c, sleep := openTester(t)
+
+ d1 := mkdigest("1")
+ err := PutBytes(c, d1, "1")
+ check(err)
+
+ err = c.Link("h/n/m:t", d1)
+ check(err)
+
+ t0 := sleep(0)
+ sleep(1 * time.Hour)
+ err = c.Link("h/n/m:t", d1) // nop expected
+ check(err)
+
+ file := must(c.manifestPath("h/n/m:t"))
+ info, err := os.Stat(file)
+ check(err)
+ testutil.CheckTime(t, info.ModTime(), t0)
+}
+
+func TestManifestExistsWithoutBlob(t *testing.T) {
+ t.Chdir(t.TempDir())
+
+ check := testutil.Checker(t)
+
+ c, err := Open(".")
+ check(err)
+
+ checkEntry := entryChecker(t, c)
+
+ man := must(c.manifestPath("h/n/m:t"))
+ os.MkdirAll(filepath.Dir(man), 0o777)
+ testutil.WriteFile(t, man, "1")
+
+ got, err := c.Resolve("h/n/m:t")
+ check(err)
+
+ want := mkdigest("1")
+ if got != want {
+ t.Fatalf("got = %v, want %v", got, want)
+ }
+
+ e, err := c.Get(got)
+ check(err)
+ checkEntry(got, 1, e.Time)
+}
+
+func TestPut(t *testing.T) {
+ c, sleep := openTester(t)
+
+ check := testutil.Checker(t)
+ checkEntry := entryChecker(t, c)
+
+ d := mkdigest("hello, world")
+ err := PutBytes(c, d, "hello")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+
+ got, err := c.Get(d)
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Fatalf("expected error, got %v", got)
+ }
+
+ // Put a valid blob
+ err = PutBytes(c, d, "hello, world")
+ check(err)
+ checkEntry(d, 12, sleep(0))
+
+ // Put a blob with content that does not hash to the digest
+ err = PutBytes(c, d, "hello")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ checkNotExists(t, c, d)
+
+ // Put the valid blob back and check it
+ err = PutBytes(c, d, "hello, world")
+ check(err)
+ checkEntry(d, 12, sleep(0))
+
+ // Put a blob that errors during Read
+ err = c.Put(d, &errOnBangReader{s: "!"}, 1)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ checkNotExists(t, c, d)
+
+ // Put valid blob back and check it
+ err = PutBytes(c, d, "hello, world")
+ check(err)
+ checkEntry(d, 12, sleep(0))
+
+ // Put a blob with mismatched size
+ err = c.Put(d, strings.NewReader("hello, world"), 11)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ checkNotExists(t, c, d)
+
+ // Final byte does not match the digest (testing commit phase)
+ err = PutBytes(c, d, "hello, world$")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ checkNotExists(t, c, d)
+
+ reset := c.setTestHookBeforeFinalWrite(func(f *os.File) {
+ // change mode to read-only
+ f.Truncate(0)
+ f.Chmod(0o400)
+ f.Close()
+ f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() { f1.Close() })
+ *f = *f1
+ })
+ defer reset()
+
+ err = PutBytes(c, d, "hello, world")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ checkNotExists(t, c, d)
+ reset()
+}
+
+func TestImport(t *testing.T) {
+ c, _ := openTester(t)
+
+ checkEntry := entryChecker(t, c)
+
+ want := mkdigest("x")
+ got, err := c.Import(strings.NewReader("x"), 1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want != got {
+ t.Fatalf("digest = %v, want %v", got, want)
+ }
+ checkEntry(want, 1, epoch)
+
+ got, err = c.Import(strings.NewReader("x"), 1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want != got {
+ t.Fatalf("digest = %v, want %v", got, want)
+ }
+ checkEntry(want, 1, epoch)
+}
+
+func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) {
+ old := c.testHookBeforeFinalWrite
+ c.testHookBeforeFinalWrite = h
+ return func() { c.testHookBeforeFinalWrite = old }
+}
+
+func TestPutGetZero(t *testing.T) {
+ c, sleep := openTester(t)
+
+ check := testutil.Checker(t)
+ checkEntry := entryChecker(t, c)
+
+ d := mkdigest("x")
+ err := PutBytes(c, d, "x")
+ check(err)
+ checkEntry(d, 1, sleep(0))
+
+ err = os.Truncate(c.GetFile(d), 0)
+ check(err)
+
+ _, err = c.Get(d)
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Fatalf("err = %v, want fs.ErrNotExist", err)
+ }
+}
+
+func TestPutZero(t *testing.T) {
+ c, _ := openTester(t)
+ d := mkdigest("x")
+ err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content)
+ testutil.Check(t, err)
+ checkNotExists(t, c, d)
+}
+
+func TestCommit(t *testing.T) {
+ check := testutil.Checker(t)
+
+ c, err := Open(t.TempDir())
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkEntry := entryChecker(t, c)
+
+ now := epoch
+ c.now = func() time.Time { return now }
+
+ d1 := mkdigest("1")
+ err = c.Link("h/n/m:t", d1)
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Fatalf("err = %v, want fs.ErrNotExist", err)
+ }
+
+ err = PutBytes(c, d1, "1")
+ check(err)
+
+ err = c.Link("h/n/m:t", d1)
+ check(err)
+
+ got, err := c.Resolve("h/n/m:t")
+ check(err)
+ if got != d1 {
+ t.Fatalf("d = %v, want %v", got, d1)
+ }
+
+ // commit again, more than 1 byte
+ d2 := mkdigest("22")
+ err = PutBytes(c, d2, "22")
+ check(err)
+ err = c.Link("h/n/m:t", d2)
+ check(err)
+ checkEntry(d2, 2, now)
+
+ filename := must(c.manifestPath("h/n/m:t"))
+ data, err := os.ReadFile(filename)
+ check(err)
+ if string(data) != "22" {
+ t.Fatalf("data = %q, want %q", data, "22")
+ }
+
+ t0 := now
+ now = now.Add(1 * time.Hour)
+ err = c.Link("h/n/m:t", d2) // same contents; nop
+ check(err)
+ info, err := os.Stat(filename)
+ check(err)
+ testutil.CheckTime(t, info.ModTime(), t0)
+}
+
+func TestManifestInvalidBlob(t *testing.T) {
+ c, _ := openTester(t)
+ d := mkdigest("1")
+ err := c.Link("h/n/m:t", d)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ checkNotExists(t, c, d)
+
+ err = PutBytes(c, d, "1")
+ testutil.Check(t, err)
+ err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = c.Link("h/n/m:t", d)
+ if !strings.Contains(err.Error(), "underfoot") {
+ t.Fatalf("err = %v, want error to contain %q", err, "underfoot")
+ }
+}
+
+func TestManifestNameReuse(t *testing.T) {
+ t.Run("case-insensitive", func(t *testing.T) {
+ // This should run on all file system types.
+ testManifestNameReuse(t)
+ })
+ t.Run("case-sensitive", func(t *testing.T) {
+ useCaseInsensitiveTempDir(t)
+ testManifestNameReuse(t)
+ })
+}
+
+func testManifestNameReuse(t *testing.T) {
+ check := testutil.Checker(t)
+
+ c, _ := openTester(t)
+
+ d1 := mkdigest("1")
+ err := PutBytes(c, d1, "1")
+ check(err)
+ err = c.Link("h/n/m:t", d1)
+ check(err)
+
+ d2 := mkdigest("22")
+ err = PutBytes(c, d2, "22")
+ check(err)
+ err = c.Link("H/N/M:T", d2)
+ check(err)
+
+ var g [2]Digest
+ g[0], err = c.Resolve("h/n/m:t")
+ check(err)
+ g[1], err = c.Resolve("H/N/M:T")
+ check(err)
+
+ w := [2]Digest{d2, d2}
+ if g != w {
+ t.Fatalf("g = %v, want %v", g, w)
+ }
+
+ var got []string
+ for l, err := range c.links() {
+ if err != nil {
+ t.Fatal(err)
+ }
+ got = append(got, l)
+ }
+ want := []string{"manifests/h/n/m/t"}
+ if !slices.Equal(got, want) {
+ t.Fatalf("got = %v, want %v", got, want)
+ }
+
+ // relink with different case
+ err = c.Unlink("h/n/m:t")
+ check(err)
+ err = c.Link("h/n/m:T", d1)
+ check(err)
+
+ got = got[:0]
+ for l, err := range c.links() {
+ if err != nil {
+ t.Fatal(err)
+ }
+ got = append(got, l)
+ }
+
+ // we should have only one link that is same case as the last link
+ want = []string{"manifests/h/n/m/T"}
+ if !slices.Equal(got, want) {
+ t.Fatalf("got = %v, want %v", got, want)
+ }
+}
+
+func TestManifestFile(t *testing.T) {
+ cases := []struct {
+ in string
+ want string
+ }{
+ {"", ""},
+
+ // valid names
+ {"h/n/m:t", "/manifests/h/n/m/t"},
+ {"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"},
+
+ {"%/%/%/%", ""},
+
+ // already a path
+ {"h/n/m/t", ""},
+
+ // refs are not names
+ {"h/n/m:t@sha256-1", ""},
+ {"m@sha256-1", ""},
+ {"n/m:t@sha256-1", ""},
+ }
+
+ c, _ := openTester(t)
+ for _, tt := range cases {
+ t.Run(tt.in, func(t *testing.T) {
+ got, err := c.manifestPath(tt.in)
+ if err != nil && tt.want != "" {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if err == nil && tt.want == "" {
+ t.Fatalf("expected error")
+ }
+ dir := filepath.ToSlash(c.dir)
+ got = filepath.ToSlash(got)
+ got = strings.TrimPrefix(got, dir)
+ if got != tt.want {
+ t.Fatalf("got = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestNames(t *testing.T) {
+ c, _ := openTester(t)
+ check := testutil.Checker(t)
+
+ check(PutBytes(c, mkdigest("1"), "1"))
+ check(PutBytes(c, mkdigest("2"), "2"))
+
+ check(c.Link("h/n/m:t", mkdigest("1")))
+ check(c.Link("h/n/m:u", mkdigest("2")))
+
+ var got []string
+ for l, err := range c.Links() {
+ if err != nil {
+ t.Fatal(err)
+ }
+ got = append(got, l)
+ }
+ want := []string{"h/n/m:t", "h/n/m:u"}
+ if !slices.Equal(got, want) {
+ t.Fatalf("got = %v, want %v", got, want)
+ }
+}
+
+func mkdigest(s string) Digest {
+ return Digest{sha256.Sum256([]byte(s))}
+}
+
+func checkNotExists(t *testing.T, c *DiskCache, d Digest) {
+ t.Helper()
+ _, err := c.Get(d)
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Fatalf("err = %v, want fs.ErrNotExist", err)
+ }
+}
+
+func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) {
+ t.Helper()
+ return func(d Digest, size int64, mod time.Time) {
+ t.Helper()
+ t.Run("checkEntry:"+d.String(), func(t *testing.T) {
+ t.Helper()
+
+ defer func() {
+ if t.Failed() {
+ dumpCacheContents(t, c)
+ }
+ }()
+
+ e, err := c.Get(d)
+ if size == 0 && errors.Is(err, fs.ErrNotExist) {
+ err = nil
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if e.Digest != d {
+ t.Errorf("e.Digest = %v, want %v", e.Digest, d)
+ }
+ if e.Size != size {
+ t.Fatalf("e.Size = %v, want %v", e.Size, size)
+ }
+
+ testutil.CheckTime(t, e.Time, mod)
+ info, err := os.Stat(c.GetFile(d))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if info.Size() != size {
+ t.Fatalf("info.Size = %v, want %v", info.Size(), size)
+ }
+ testutil.CheckTime(t, info.ModTime(), mod)
+ })
+ }
+}
+
+func must[T any](v T, err error) T {
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+func TestNameToPath(t *testing.T) {
+ _, err := nameToPath("h/n/m:t")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+type errOnBangReader struct {
+ s string
+ n int
+}
+
+func (e *errOnBangReader) Read(p []byte) (int, error) {
+ if len(p) < 1 {
+ return 0, io.ErrShortBuffer
+ }
+ if e.n >= len(e.s) {
+ return 0, io.EOF
+ }
+ if e.s[e.n] == '!' {
+ return 0, errors.New("bang")
+ }
+ p[0] = e.s[e.n]
+ e.n++
+ return 1, nil
+}
+
+func dumpCacheContents(t *testing.T, c *DiskCache) {
+ t.Helper()
+
+ var b strings.Builder
+ fsys := os.DirFS(c.dir)
+ fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
+ t.Helper()
+
+ if err != nil {
+ return err
+ }
+ info, err := d.Info()
+ if err != nil {
+ return err
+ }
+
+ // Format like ls:
+ //
+ // ; ls -la
+ // drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123
+ // drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m
+
+ fmt.Fprintf(&b, " %s % 4d %s %s\n",
+ info.Mode(),
+ info.Size(),
+ info.ModTime().Format("Jan 2 15:04"),
+ path,
+ )
+ return nil
+ })
+ t.Log()
+ t.Logf("cache contents:\n%s", b.String())
+}
diff --git a/server/internal/cache/blob/casecheck_test.go b/server/internal/cache/blob/casecheck_test.go
new file mode 100644
index 000000000..f0842ef91
--- /dev/null
+++ b/server/internal/cache/blob/casecheck_test.go
@@ -0,0 +1,93 @@
+package blob
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func isCaseSensitive(dir string) bool {
+ defer func() {
+ os.Remove(filepath.Join(dir, "_casecheck"))
+ }()
+
+ exists := func(file string) bool {
+ _, err := os.Stat(file)
+ return err == nil
+ }
+
+ file := filepath.Join(dir, "_casecheck")
+ FILE := filepath.Join(dir, "_CASECHECK")
+ if exists(file) || exists(FILE) {
+ panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir))
+ }
+
+ err := os.WriteFile(file, nil, 0o666)
+ if err != nil {
+ panic(err)
+ }
+
+ return !exists(FILE)
+}
+
+func isCI() bool {
+ return os.Getenv("CI") != ""
+}
+
+const volumeHint = `
+
+ Unable to locate case-insensitive TMPDIR on darwin.
+
+ To run tests, create the case-insensitive volume /Volumes/data:
+
+ $ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data
+
+ or run with:
+
+ CI=1 go test ./...
+
+`
+
+// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory if it
+// can find one; otherwise it skips the test if the CI environment variable is
+// set, or GOOS is not darwin.
+func useCaseInsensitiveTempDir(t *testing.T) bool {
+ if isCaseSensitive(os.TempDir()) {
+ // Use the default temp dir if it is already case-sensitive.
+ return true
+ }
+ if runtime.GOOS == "darwin" {
+ // If darwin, check for the special case-sensitive volume and
+ // use it if available.
+ const volume = "/Volumes/data"
+ _, err := os.Stat(volume)
+ if err == nil {
+ tmpdir := filepath.Join(volume, "tmp")
+ os.MkdirAll(tmpdir, 0o700)
+ t.Setenv("TMPDIR", tmpdir)
+ return true
+ }
+ if isCI() {
+ // Special case darwin in CI; it is not case-sensitive
+ // by default, and we will be testing other platforms
+ // that are case-sensitive, so we'll have the test
+ // being skipped covered there.
+ t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.")
+ }
+ }
+
+ if !isCI() {
+ // Require devs to always test with a case-insensitive TMPDIR.
+
+ // TODO(bmizerany): Print platform-specific instructions or
+ // link to docs on that topic.
+ lines := strings.Split(volumeHint, "\n")
+ for _, line := range lines {
+ t.Log(line)
+ }
+ }
+ return false
+}
diff --git a/server/internal/cache/blob/digest.go b/server/internal/cache/blob/digest.go
new file mode 100644
index 000000000..723ba222c
--- /dev/null
+++ b/server/internal/cache/blob/digest.go
@@ -0,0 +1,95 @@
+package blob
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "slices"
+ "strings"
+)
+
+var ErrInvalidDigest = errors.New("invalid digest")
+
+// Digest is a blob identifier that is the SHA-256 hash of a blob's content.
+//
+// It is comparable and can be used as a map key.
+type Digest struct {
+ sum [32]byte
+}
+
+// ParseDigest parses a digest from a string. If the string is not a valid
+// digest, a call to the returned digest's IsValid method will return false.
+//
+// The input string may be in one of two forms:
+//
+// - ("sha256-"), where is a 64-character hexadecimal string.
+// - ("sha256:"), where is a 64-character hexadecimal string.
+//
+// The [Digest.String] method will return the canonical form of the
+// digest, "sha256:".
+func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) {
+ s := string(v)
+ i := strings.IndexAny(s, ":-")
+ var zero Digest
+ if i < 0 {
+ return zero, ErrInvalidDigest
+ }
+
+ prefix, sum := s[:i], s[i+1:]
+ if prefix != "sha256" || len(sum) != 64 {
+ return zero, ErrInvalidDigest
+ }
+
+ var d Digest
+ _, err := hex.Decode(d.sum[:], []byte(sum))
+ if err != nil {
+ return zero, ErrInvalidDigest
+ }
+ return d, nil
+}
+
+func DigestFromBytes[S ~[]byte | ~string](v S) Digest {
+ return Digest{sha256.Sum256([]byte(v))}
+}
+
+// String returns the string representation of the digest in the conventional
+// form "sha256:".
+func (d Digest) String() string {
+ return fmt.Sprintf("sha256:%x", d.sum[:])
+}
+
+func (d Digest) Short() string {
+ return fmt.Sprintf("%x", d.sum[:4])
+}
+
+func (d Digest) Compare(other Digest) int {
+ return slices.Compare(d.sum[:], other.sum[:])
+}
+
+// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash
+// of some content.
+func (d Digest) IsValid() bool {
+ return d != (Digest{})
+}
+
+// MarshalText implements the encoding.TextMarshaler interface. It returns an
+// error if [Digest.IsValid] returns false.
+func (d Digest) MarshalText() ([]byte, error) {
+ return []byte(d.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface, and only
+// works for a zero digest. If [Digest.IsValid] returns true, it returns an
+// error.
+func (d *Digest) UnmarshalText(text []byte) error {
+ if *d != (Digest{}) {
+ return errors.New("digest: illegal UnmarshalText on valid digest")
+ }
+ v, err := ParseDigest(string(text))
+ if err != nil {
+ return err
+ }
+ *d = v
+ return nil
+}
diff --git a/server/internal/cache/blob/digest_test.go b/server/internal/cache/blob/digest_test.go
new file mode 100644
index 000000000..c96ad383b
--- /dev/null
+++ b/server/internal/cache/blob/digest_test.go
@@ -0,0 +1,63 @@
+package blob
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestParseDigest(t *testing.T) {
+ cases := []struct {
+ in string
+ valid bool
+ }{
+ {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
+ {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
+
+ // too short
+ {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
+ {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
+
+ // too long
+ {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
+ {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
+
+ // invalid prefix
+ {"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
+ {"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
+ {"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
+
+ // invalid hex
+ {"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
+ {"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
+ }
+
+ for _, tt := range cases {
+ got, err := ParseDigest(tt.in)
+ if tt.valid && err != nil {
+ t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err)
+ }
+ want := "sha256:" + tt.in[7:]
+ if tt.valid && got.String() != want {
+ t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want)
+ }
+ }
+}
+
+func TestDigestMarshalText(t *testing.T) {
+ const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
+ var d Digest
+ if err := json.Unmarshal([]byte(s), &d); err != nil {
+ t.Errorf("json.Unmarshal: %v", err)
+ }
+ out, err := json.Marshal(d)
+ if err != nil {
+ t.Errorf("json.Marshal: %v", err)
+ }
+ want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
+ if string(out) != want {
+ t.Errorf("json.Marshal: got %s, want %s", out, want)
+ }
+ if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil {
+ t.Errorf("json.Unmarshal: expected error")
+ }
+}
diff --git a/server/internal/chunks/chunks.go b/server/internal/chunks/chunks.go
new file mode 100644
index 000000000..7eb7a6c17
--- /dev/null
+++ b/server/internal/chunks/chunks.go
@@ -0,0 +1,78 @@
+package chunks
+
+import (
+ "fmt"
+ "iter"
+ "strconv"
+ "strings"
+)
+
+type Chunk struct {
+ Start, End int64
+}
+
+func New(start, end int64) Chunk {
+ return Chunk{start, end}
+}
+
+// ParseRange parses a string in the form "unit=range" where unit is a string
+// and range is a string in the form "start-end". It returns the unit and the
+// range as a Chunk.
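+//
+// For example, ParseRange("bytes=0-99") returns ("bytes", Chunk{0, 99}, nil).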
+func ParseRange(s string) (unit string, _ Chunk, _ error) {
+ unit, r, _ := strings.Cut(s, "=")
+ if r == "" {
+ return unit, Chunk{}, nil
+ }
+ c, err := Parse(r)
+ if err != nil {
+ return "", Chunk{}, err
+ }
+ return unit, c, err
+}
+
+// Parse parses a string in the form "start-end" and returns the Chunk.
+func Parse(s string) (Chunk, error) {
+ startStr, endStr, _ := strings.Cut(s, "-")
+ start, err := strconv.ParseInt(startStr, 10, 64)
+ if err != nil {
+ return Chunk{}, fmt.Errorf("invalid start: %v", err)
+ }
+ end, err := strconv.ParseInt(endStr, 10, 64)
+ if err != nil {
+ return Chunk{}, fmt.Errorf("invalid end: %v", err)
+ }
+ if start > end {
+ return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
+ }
+ return Chunk{start, end}, nil
+}
+
+// Of returns a sequence of contiguous Chunks of size chunkSize that cover
+// the range [0, size), in order.
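+//
+// For example, Of(10, 4) yields {0, 3}, {4, 7}, {8, 9}.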
+func Of(size, chunkSize int64) iter.Seq[Chunk] {
+ return func(yield func(Chunk) bool) {
+ for start := int64(0); start < size; start += chunkSize {
+ end := min(start+chunkSize-1, size-1)
+ if !yield(Chunk{start, end}) {
+ break
+ }
+ }
+ }
+}
+
+// Count returns the number of Chunks of size chunkSize needed to cover the
+// range [0, size).
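+//
+// For example, Count(10, 4) is 3.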
+func Count(size, chunkSize int64) int64 {
+ return (size + chunkSize - 1) / chunkSize
+}
+
+// Size returns end minus start plus one.
+func (c Chunk) Size() int64 {
+ return c.End - c.Start + 1
+}
+
+// String returns the string representation of the Chunk in the form
+// "{start}-{end}".
+func (c Chunk) String() string {
+ return fmt.Sprintf("%d-%d", c.Start, c.End)
+}
diff --git a/server/internal/chunks/chunks_test.go b/server/internal/chunks/chunks_test.go
new file mode 100644
index 000000000..c23e0de8e
--- /dev/null
+++ b/server/internal/chunks/chunks_test.go
@@ -0,0 +1,65 @@
+package chunks
+
+import (
+ "slices"
+ "testing"
+)
+
+func TestOf(t *testing.T) {
+ cases := []struct {
+ total int64
+ chunkSize int64
+ want []Chunk
+ }{
+ {0, 1, nil},
+ {1, 1, []Chunk{{0, 0}}},
+ {1, 2, []Chunk{{0, 0}}},
+ {2, 1, []Chunk{{0, 0}, {1, 1}}},
+ {10, 9, []Chunk{{0, 8}, {9, 9}}},
+ }
+
+ for _, tt := range cases {
+ got := slices.Collect(Of(tt.total, tt.chunkSize))
+ if !slices.Equal(got, tt.want) {
+ t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
+ }
+ }
+}
+
+func TestSize(t *testing.T) {
+ cases := []struct {
+ c Chunk
+ want int64
+ }{
+ {Chunk{0, 0}, 1},
+ {Chunk{0, 1}, 2},
+ {Chunk{3, 4}, 2},
+ }
+
+ for _, tt := range cases {
+ got := tt.c.Size()
+ if got != tt.want {
+ t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
+ }
+ }
+}
+
+func TestCount(t *testing.T) {
+ cases := []struct {
+ total int64
+ chunkSize int64
+ want int64
+ }{
+ {0, 1, 0},
+ {1, 1, 1},
+ {1, 2, 1},
+ {2, 1, 2},
+ {10, 9, 2},
+ }
+ for _, tt := range cases {
+ got := Count(tt.total, tt.chunkSize)
+ if got != tt.want {
+ t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
+ }
+ }
+}
diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
new file mode 100644
index 000000000..136122721
--- /dev/null
+++ b/server/internal/client/ollama/registry.go
@@ -0,0 +1,802 @@
+// Package ollama provides a client for interacting with an Ollama registry,
+// pushing and pulling model manifests and layers as defined by
+// [ollama.com/manifest].
+package ollama
+
+import (
+ "bufio"
+ "bytes"
+ "cmp"
+ "context"
+ "crypto"
+ "crypto/ed25519"
+ "crypto/sha256"
+ "crypto/tls"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "net/http"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/sync/errgroup"
+
+ "github.com/ollama/ollama/server/internal/cache/blob"
+ "github.com/ollama/ollama/server/internal/chunks"
+ "github.com/ollama/ollama/server/internal/internal/backoff"
+ "github.com/ollama/ollama/server/internal/internal/names"
+ "github.com/ollama/ollama/server/internal/internal/syncs"
+
+ _ "embed"
+)
+
+// Errors
+var (
+ // ErrManifestNotFound is returned when a manifest is not found in the
+ // cache or registry.
+ ErrManifestNotFound = errors.New("manifest not found")
+
+ // ErrManifestInvalid is returned when a manifest found in a local or
+ // remote cache is invalid.
+ ErrManifestInvalid = errors.New("invalid manifest")
+
+ // ErrNameInvalid is returned when a name is missing required parts or
+ // is otherwise invalid.
+ ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
+
+ // ErrCached is passed to [Trace.PushUpdate] when a layer already
+ // exists. It is a non-fatal error and is never returned by [Registry.Push].
+ ErrCached = errors.New("cached")
+)
+
+// Defaults
+const (
+ // DefaultChunkingThreshold is the threshold at which a layer should be
+ // split up into chunks when downloading.
+ DefaultChunkingThreshold = 128 << 20
+
+ // DefaultMaxChunkSize is the default maximum size of a chunk to
+ // download. It is configured based on benchmarks and aims to strike a
+ // balance between download speed and memory usage.
+ DefaultMaxChunkSize = 8 << 20
+)
+
+// DefaultCache returns a new disk cache for storing models. If the
+// OLLAMA_MODELS environment variable is set, it uses that directory;
+// otherwise, it uses $HOME/.ollama/models.
+func DefaultCache() (*blob.DiskCache, error) {
+ dir := os.Getenv("OLLAMA_MODELS")
+ if dir == "" {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil, err
+ }
+ dir = filepath.Join(home, ".ollama", "models")
+ }
+ return blob.Open(dir)
+}
+
+// Error is the standard error returned by Ollama APIs.
+type Error struct {
+ Status int `json:"-"`
+ Code string `json:"code"`
+ Message string `json:"message"`
+}
+
+func (e *Error) Error() string {
+ return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
+}
+
+// UnmarshalJSON implements json.Unmarshaler.
+func (e *Error) UnmarshalJSON(b []byte) error {
+ type E Error
+ var v struct{ Errors []E }
+ if err := json.Unmarshal(b, &v); err != nil {
+ return err
+ }
+ if len(v.Errors) == 0 {
+ return fmt.Errorf("no messages in error response: %s", string(b))
+ }
+ *e = Error(v.Errors[0]) // our registry only returns one error.
+ return nil
+}
+
+// TODO(bmizerany): make configurable on [Registry]
+var defaultName = func() names.Name {
+ n := names.Parse("ollama.com/library/_:latest")
+ if !n.IsFullyQualified() {
+ panic("default name is not fully qualified")
+ }
+ return n
+}()
+
+// Registry is a client for performing push and pull operations against an
+// Ollama registry.
+type Registry struct {
+ // UserAgent is the User-Agent header to send with requests to the
+ // registry. If empty, the User-Agent is determined by HTTPClient.
+ UserAgent string
+
+ // Key is the key used to authenticate with the registry.
+ //
+ // Currently, only Ed25519 keys are supported.
+ Key crypto.PrivateKey
+
+ // HTTPClient is the HTTP client used to make requests to the registry.
+ //
+ // If nil, [http.DefaultClient] is used.
+ //
+ // As a quick note: if a Registry function makes a call to a URL with
+ // the "https+insecure" scheme, the client will be cloned and the
+ // transport will be set to skip TLS verification, unless the client's
+ // Transport does not have a Clone method with the same signature as
+ // [http.Transport.Clone], in which case the call will fail.
+ HTTPClient *http.Client
+
+ // MaxStreams is the maximum number of concurrent streams to use when
+ // pushing or pulling models. If zero, the number of streams is
+ // determined by [runtime.GOMAXPROCS].
+ //
+ // Clients that want "unlimited" streams should set this to a large
+ // number.
+ MaxStreams int
+
+ // ChunkingThreshold is the maximum size of a layer to download in a single
+ // request. If zero, [DefaultChunkingThreshold] is used.
+ ChunkingThreshold int64
+
+ // MaxChunkSize is the maximum size of a chunk to download. If zero,
+ // the default is [DefaultMaxChunkSize].
+ //
+ // It is only used when a layer is larger than [Registry.ChunkingThreshold].
+ MaxChunkSize int64
+}
+
+// RegistryFromEnv returns a new Registry configured from the environment. The
+// key is read from $HOME/.ollama/id_ed25519, and MaxStreams is set to the
+// value of OLLAMA_REGISTRY_MAXSTREAMS, if set.
+//
+// It returns an error if any configuration in the environment is invalid.
+func RegistryFromEnv() (*Registry, error) {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil, err
+ }
+ keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
+ if err != nil {
+ return nil, err
+ }
+
+ var rc Registry
+ rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
+ if err != nil {
+ return nil, err
+ }
+ maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
+ if maxStreams != "" {
+ var err error
+ rc.MaxStreams, err = strconv.Atoi(maxStreams)
+ if err != nil {
+ return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
+ }
+ }
+ return &rc, nil
+}
+
+type PushParams struct {
+ // From is an optional destination name for the model. If empty, the
+ // destination name is the same as the source name.
+ From string
+}
+
+// parseName parses name using [names.ParseExtended], merges the result with
+// the default name, and checks that the name is fully qualified. If a digest
+// is present, it is parsed and returned along with the other fields.
+//
+// It returns an error if the name is not fully qualified, or if the digest, if
+// any, is invalid.
+//
+// The scheme is returned as provided by [names.ParseExtended].
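+//
+// For example, parseName("m") fills in the defaults and yields the "https"
+// scheme and a fully qualified name such as "ollama.com/library/m:latest".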
+func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
+ scheme, n, ds := names.ParseExtended(s)
+ n = names.Merge(n, defaultName)
+ if ds != "" {
+ // Digest is present. Validate it.
+ d, err = blob.ParseDigest(ds)
+ if err != nil {
+ return "", names.Name{}, blob.Digest{}, err
+ }
+ }
+
+ // The name check is deferred until after the digest check because we
+ // say that digests take precedence over names, and so should their
+ // errors when being parsed.
+ if !n.IsFullyQualified() {
+ return "", names.Name{}, blob.Digest{}, ErrNameInvalid
+ }
+
+ scheme = cmp.Or(scheme, "https")
+ return scheme, n, d, nil
+}
+
+func (r *Registry) maxStreams() int {
+ n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
+
+ // Large downloads require a writer stream, so ensure we have at least
+ // two streams to avoid a deadlock.
+ return max(n, 2)
+}
+
+func (r *Registry) maxChunkingThreshold() int64 {
+ return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
+}
+
+// maxChunkSize returns the maximum chunk size to use when downloading a
+// layer in chunks: [Registry.MaxChunkSize] if set, otherwise
+// [DefaultMaxChunkSize].
+func (r *Registry) maxChunkSize() int64 {
+ return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
+}
+
+// Push pushes the model with the name in the cache to the remote registry.
+func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
+ if p == nil {
+ p = &PushParams{}
+ }
+
+ m, err := ResolveLocal(c, cmp.Or(p.From, name))
+ if err != nil {
+ return err
+ }
+
+ // Before much else happens, check that the layers are not null, the
+ // blobs exist, and the sizes match. This prevents long uploads
+ // followed by disappointment.
+ for _, l := range m.Layers {
+ if l == nil {
+ return fmt.Errorf("%w: null layer", ErrManifestInvalid)
+ }
+ info, err := c.Get(l.Digest)
+ if err != nil {
+ return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
+ }
+ if info.Size != l.Size {
+ return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
+ }
+ }
+
+ t := traceFromContext(ctx)
+
+ scheme, n, _, err := parseName(name)
+ if err != nil {
+ // This should never happen since ResolveLocal should have
+ // already validated the name.
+ panic(err)
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ var g errgroup.Group
+ g.SetLimit(r.maxStreams())
+ for _, l := range m.Layers {
+ var progress atomic.Int64
+ g.Go(func() (err error) {
+ defer func() { t.update(l, progress.Load(), err) }()
+
+ t.update(l, 0, nil)
+
+ startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
+ scheme,
+ n.Host(),
+ n.Namespace(),
+ n.Model(),
+ l.Digest,
+ )
+ res, err := r.doOK(ctx, "POST", startURL, nil)
+ if err != nil {
+ return err
+ }
+ res.Body.Close()
+
+ f, err := os.Open(c.GetFile(l.Digest))
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ uploadURL := res.Header.Get("Location")
+ if uploadURL == "" {
+ t.update(l, l.Size, ErrCached)
+ return nil
+ }
+
+ req, err := r.newRequest(ctx, "PUT", uploadURL, f)
+ if err != nil {
+ return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
+ }
+ req.ContentLength = l.Size
+
+ res, err = doOK(r.client(), req)
+ if err == nil {
+ res.Body.Close()
+ }
+ return err
+ })
+ }
+
+ if err := g.Wait(); err != nil {
+ return err
+ }
+
+ // Commit
+ path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
+ scheme,
+ n.Host(),
+ n.Namespace(),
+ n.Model(),
+ n.Tag(),
+ )
+ res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
+ if err == nil {
+ res.Body.Close()
+ }
+ // TODO(bmizerany): add a "commit" trace event
+ return err
+}
+
+func canRetry(err error) bool {
+ var re *Error
+ if !errors.As(err, &re) {
+ return false
+ }
+ return re.Status >= 500
+}
+
+// Pull pulls the model with the given name from the remote registry into the
+// cache.
+//
+// For layers larger than [Registry.ChunkingThreshold], the layer is
+// downloaded in chunks of at most [Registry.MaxChunkSize], and then
+// reassembled and verified. This is typically slower than splitting the model
+// up across layers, and is mostly utilized for layers of type equal to
+// "application/vnd.ollama.image".
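+//
+// A minimal usage sketch (the model name is illustrative):
+//
+//	c, err := DefaultCache()
+//	r, err := RegistryFromEnv()
+//	err = r.Pull(context.TODO(), c, "ollama.com/library/some-model:latest")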
+func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
+ scheme, n, _, err := parseName(name)
+ if err != nil {
+ return err
+ }
+
+ m, err := r.Resolve(ctx, name)
+ if err != nil {
+ return err
+ }
+ if len(m.Layers) == 0 {
+ return fmt.Errorf("%w: no layers", ErrManifestInvalid)
+ }
+
+ exists := func(l *Layer) bool {
+ info, err := c.Get(l.Digest)
+ return err == nil && info.Size == l.Size
+ }
+
+ t := traceFromContext(ctx)
+
+ var g errgroup.Group
+ g.SetLimit(r.maxStreams())
+
+ for _, l := range m.Layers {
+ if exists(l) {
+ t.update(l, l.Size, ErrCached)
+ continue
+ }
+
+ blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
+ req, err := r.newRequest(ctx, "GET", blobURL, nil)
+ if err != nil {
+ t.update(l, 0, err)
+ continue
+ }
+
+ t.update(l, 0, nil)
+
+ if l.Size <= r.maxChunkingThreshold() {
+ g.Go(func() error {
+ res, err := doOK(r.client(), req)
+ if err != nil {
+ return err
+ }
+ defer res.Body.Close()
+ err = c.Put(l.Digest, res.Body, l.Size)
+ if err == nil {
+ t.update(l, l.Size, nil)
+ }
+ return err
+ })
+ } else {
+ q := syncs.NewRelayReader()
+
+ g.Go(func() (err error) {
+ defer func() { q.CloseWithError(err) }()
+ return c.Put(l.Digest, q, l.Size)
+ })
+
+ var progress atomic.Int64
+
+ // We want to avoid extra round trips per chunk due to
+ // redirects from the registry to the blob store, so
+ // fire an initial request to get the final URL and
+ // then use that URL for the chunk requests.
+ req.Header.Set("Range", "bytes=0-0")
+ res, err := doOK(r.client(), req)
+ if err != nil {
+ return err
+ }
+ res.Body.Close()
+ req = res.Request.WithContext(req.Context())
+
+ streamNo := 0
+ tws := make([]*bufio.Writer, r.maxStreams()-1)
+ for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
+ ticket := q.Take()
+ bufIdx := streamNo % len(tws)
+ streamNo++
+ g.Go(func() (err error) {
+ defer func() {
+ if err != nil {
+ q.CloseWithError(err)
+ }
+ ticket.Close()
+ t.update(l, progress.Load(), err)
+ }()
+
+ for _, err := range backoff.Loop(ctx, 3*time.Second) {
+ if err != nil {
+ return err
+ }
+
+ err := func() error {
+ req := req.Clone(req.Context())
+ req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
+ res, err := doOK(r.client(), req)
+ if err != nil {
+ return err
+ }
+ defer res.Body.Close()
+
+ tw := tws[bufIdx]
+ if tw == nil {
+ tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
+ tws[bufIdx] = tw
+ }
+ tw.Reset(ticket)
+ defer tw.Reset(nil) // release ticket
+
+ _, err = io.CopyN(tw, res.Body, chunk.Size())
+ if err != nil {
+ return maybeUnexpectedEOF(err)
+ }
+ if err := tw.Flush(); err != nil {
+ return err
+ }
+
+ total := progress.Add(chunk.Size())
+ if total >= l.Size {
+ q.Close()
+ }
+ return nil
+ }()
+ if !canRetry(err) {
+ return err
+ }
+ }
+ return nil
+ })
+ }
+ }
+ }
+
+ if err := g.Wait(); err != nil {
+ return err
+ }
+
+ // store the manifest blob
+ md := blob.DigestFromBytes(m.Data)
+ if err := blob.PutBytes(c, md, m.Data); err != nil {
+ return err
+ }
+
+ // commit the manifest with a link
+ return c.Link(m.Name, md)
+}
+
+// Manifest represents a [ollama.com/manifest].
+type Manifest struct {
+ Name string `json:"-"` // the canonical name of the model
+ Data []byte `json:"-"` // the raw data of the manifest
+ Layers []*Layer `json:"layers"`
+}
+
+var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
+
+// Layer returns the layer with the given
+// digest, or nil if not found.
+func (m *Manifest) Layer(d blob.Digest) *Layer {
+ for _, l := range m.Layers {
+ if l.Digest == d {
+ return l
+ }
+ }
+ return nil
+}
+
+// MarshalJSON implements json.Marshaler.
+//
+// NOTE: It adds an empty config object to the manifest, which is required by
+// the registry, but not used by the client. In the future, the config object
+// will not be required by the registry and this should be removed.
+func (m Manifest) MarshalJSON() ([]byte, error) {
+ type M Manifest
+ v := struct {
+ M
+
+ // This is ignored, mostly, by the registry. But, if not
+ // present, it will cause an error to be returned during the
+ // last phase of the commit, which expects it but does nothing
+ // with it. This will be fixed in a future release of
+ // ollama.com.
+ Config *Layer `json:"config"`
+ }{
+ M: M(m),
+ Config: &Layer{Digest: emptyDigest},
+ }
+ return json.Marshal(v)
+}
+
+// unmarshalManifest unmarshals the data into a manifest, and sets the name
+// field to the string representation of the name.
+//
+// It panics if the name is not fully qualified. Callers should ensure the name
+// is fully qualified before calling this function.
+func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
+ if !n.IsFullyQualified() {
+ panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
+ }
+ var m Manifest
+ if err := json.Unmarshal(data, &m); err != nil {
+ return nil, err
+ }
+ m.Name = n.String()
+ m.Data = data
+ return &m, nil
+}
+
+// Layer is a layer in a model.
+type Layer struct {
+ Digest blob.Digest `json:"digest"`
+ MediaType string `json:"mediaType"`
+ Size int64 `json:"size"`
+}
+
+// ResolveLocal resolves a name to a Manifest in the local cache. The name is
+// parsed using [names.ParseExtended] but the scheme is ignored.
+func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
+ _, n, d, err := parseName(name)
+ if err != nil {
+ return nil, err
+ }
+ if !d.IsValid() {
+ d, err = c.Resolve(n.String())
+ if err != nil {
+ return nil, err
+ }
+ }
+ data, err := os.ReadFile(c.GetFile(d))
+ if err != nil {
+ if errors.Is(err, fs.ErrNotExist) {
+ return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
+ }
+ return nil, err
+ }
+ m, err := unmarshalManifest(n, data)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
+ }
+ return m, nil
+}
+
+// Resolve resolves a name to a Manifest in the remote registry.
+func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
+ scheme, n, d, err := parseName(name)
+ if err != nil {
+ return nil, err
+ }
+
+ manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
+ if d.IsValid() {
+ manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
+ }
+
+ res, err := r.doOK(ctx, "GET", manifestURL, nil)
+ if err != nil {
+ return nil, err
+ }
+ defer res.Body.Close()
+ data, err := io.ReadAll(res.Body)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(bmizerany): return digest here
+ m, err := unmarshalManifest(n, data)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
+ }
+ return m, nil
+}
+
+func (r *Registry) client() *http.Client {
+ if r.HTTPClient != nil {
+ return r.HTTPClient
+ }
+ return http.DefaultClient
+}
+
+// newRequest constructs a new request, ready to use, with the given method,
+// url, and body, presigned with client Key and UserAgent.
+func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
+ req, err := http.NewRequestWithContext(ctx, method, url, body)
+ if err != nil {
+ return nil, err
+ }
+ if r.UserAgent != "" {
+ req.Header.Set("User-Agent", r.UserAgent)
+ }
+ if r.Key != nil {
+ token, err := makeAuthToken(r.Key)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+token)
+ }
+ return req, nil
+}
+
+// doOK makes a request with the given client and request, and returns the
+// response if the status code is in the 2xx range. Otherwise, an Error is
+// parsed from the response body and returned. If any other error occurs, it
+// is returned.
+func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
+ if r.URL.Scheme == "https+insecure" {
+ // TODO(bmizerany): clone client.Transport, set
+ // InsecureSkipVerify, etc.
+
+ type cloner interface {
+ Clone() *http.Transport
+ }
+
+ // Attempt to configure the transport to skip TLS verification
+ // if we can clone it, otherwise fall through and let the http
+ // client complain about the invalid scheme.
+ x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
+ if ok {
+ tr := x.Clone()
+ tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
+ tr.TLSClientConfig.InsecureSkipVerify = true
+
+ cc := *c // shallow copy
+ cc.Transport = tr
+ c = &cc
+
+ r = r.Clone(r.Context())
+ r.URL.Scheme = "https"
+
+ // fall through
+ }
+ }
+
+ res, err := c.Do(r)
+ if err != nil {
+ return nil, err
+ }
+ if res.StatusCode/100 != 2 {
+ out, err := io.ReadAll(res.Body)
+ if err != nil {
+ return nil, err
+ }
+ var re Error
+ if err := json.Unmarshal(out, &re); err != nil {
+ // Use the raw body if we can't parse it as an error object.
+ re.Message = string(out)
+ }
+ re.Status = res.StatusCode
+ return nil, &re
+ }
+ return res, nil
+}
+
+// doOK is a convenience method for making a request with newRequest and
+// passing it to doOK with r.client().
+func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
+ req, err := r.newRequest(ctx, method, path, body)
+ if err != nil {
+ return nil, err
+ }
+ return doOK(r.client(), req)
+}
+
+// makeAuthToken creates an Ollama auth token for the given private key.
+//
+// NOTE: This format is OLD, overly complex, and should be replaced. We're
+// inheriting it from the original Ollama client and ollama.com
+// implementations, so we need to support it for now.
+func makeAuthToken(key crypto.PrivateKey) (string, error) {
+ privKey, _ := key.(*ed25519.PrivateKey)
+ if privKey == nil {
+ return "", fmt.Errorf("unsupported private key type: %T", key)
+ }
+
+ // Part 1: the checkData (e.g. the URL with a timestamp)
+ url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
+
+ // Part 2: the public key
+ pubKeyShort, err := func() ([]byte, error) {
+ sshPubKey, err := ssh.NewPublicKey(privKey.Public())
+ if err != nil {
+ return nil, err
+ }
+ pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
+ if len(pubKeyParts) < 2 {
+ return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
+ }
+ pubKeyShort := pubKeyParts[1]
+ return pubKeyShort, nil
+ }()
+ if err != nil {
+ return "", err
+ }
+
+ // Part 3: the signature
+ sig := ed25519.Sign(*privKey, []byte(checkData(url)))
+
+ // Assemble the token: <base64 url>:<pubKey>:<base64 signature>
+ var b strings.Builder
+ io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
+ b.WriteByte(':')
+ b.Write(pubKeyShort)
+ b.WriteByte(':')
+ io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
+
+ return b.String(), nil
+}
+
+// The original spec for Ollama tokens was to use the SHA256 of the zero
+// string as part of the signature. I'm not sure why that was, but we still
+// need it to verify the signature.
+var zeroSum = func() string {
+ sha256sum := sha256.Sum256(nil)
+ x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))
+ return x
+}()
+
+// checkData takes a URL and creates the original string format of the
+// data signature that is used by the ollama client to sign requests
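+//
+// For example, checkData("https://ollama.com?ts=1") returns
+// "GET,https://ollama.com?ts=1," followed by zeroSum.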
+func checkData(url string) string {
+ return fmt.Sprintf("GET,%s,%s", url, zeroSum)
+}
+
+func maybeUnexpectedEOF(err error) error {
+ if errors.Is(err, io.EOF) {
+ return io.ErrUnexpectedEOF
+ }
+ return err
+}
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
new file mode 100644
index 000000000..d8f2a4077
--- /dev/null
+++ b/server/internal/client/ollama/registry_test.go
@@ -0,0 +1,656 @@
+package ollama
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "math/rand/v2"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path"
+ "reflect"
+ "slices"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/ollama/ollama/server/internal/cache/blob"
+ "github.com/ollama/ollama/server/internal/chunks"
+ "github.com/ollama/ollama/server/internal/internal/testutil"
+)
+
+func TestManifestMarshalJSON(t *testing.T) {
+ // All manifests should contain an "empty" config object.
+ var m Manifest
+ data, err := json.Marshal(m)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
+ t.Error("expected manifest to contain empty config")
+ t.Fatalf("got:\n%s", string(data))
+ }
+}
+
+func link(c *blob.DiskCache, name string, manifest string) {
+ _, n, _, err := parseName(name)
+ if err != nil {
+ panic(err)
+ }
+ d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
+ if err != nil {
+ panic(err)
+ }
+ if err := c.Link(n.String(), d); err != nil {
+ panic(err)
+ }
+}
+
+var errRoundTrip = errors.New("forced roundtrip error")
+
+type recordRoundTripper http.HandlerFunc
+
+func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ w := httptest.NewRecorder()
+ rr(w, req)
+ if w.Code == 499 {
+ return nil, errRoundTrip
+ }
+ resp := w.Result()
+ // For some reason, Response.Request is not set by httptest.NewRecorder, so we
+ // set it manually.
+ resp.Request = req
+ return resp, nil
+}
+
+// newClient constructs a cache with predefined manifests for testing. The manifests are:
+//
+// empty: no data
+// zero: no layers
+// single: one layer with the contents "exists"
+// multiple: two layers with the contents "exists" and "here"
+// notfound: a layer that does not exist in the cache
+// null: one null layer (e.g. [null])
+// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
+// invalid: a layer with invalid JSON data
+//
+// Tests that want to ensure the client does not communicate with the upstream
+// registry should pass a nil handler, which will cause a panic if
+// communication is attempted.
+//
+// To simulate a network error, pass a handler that returns a 499 status code.
+func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
+ t.Helper()
+ c, err := blob.Open(t.TempDir())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ mklayer := func(data string) *Layer {
+ return &Layer{
+ Digest: importBytes(t, c, data),
+ Size: int64(len(data)),
+ }
+ }
+
+ commit := func(name string, layers ...*Layer) {
+ t.Helper()
+ data, err := json.Marshal(&Manifest{Layers: layers})
+ if err != nil {
+ t.Fatal(err)
+ }
+ link(c, name, string(data))
+ }
+
+ link(c, "empty", "")
+ commit("zero")
+ commit("single", mklayer("exists"))
+ commit("multiple", mklayer("exists"), mklayer("present"))
+ commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
+ commit("null", nil)
+ commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
+ link(c, "invalid", "!!!!!")
+
+ rc := &Registry{
+ HTTPClient: &http.Client{
+ Transport: recordRoundTripper(h),
+ },
+ }
+ return rc, c
+}
+
+func okHandler(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+}
+
+func checkErrCode(t *testing.T, err error, status int, code string) {
+ t.Helper()
+ var e *Error
+ if !errors.As(err, &e) || e.Status != status || e.Code != code {
+ t.Errorf("err = %v; want %v %v", err, status, code)
+ }
+}
+
+func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
+ d, err := c.Import(strings.NewReader(data), int64(len(data)))
+ if err != nil {
+ t.Fatal(err)
+ }
+ return d
+}
+
+func TestRegistryPushInvalidNames(t *testing.T) {
+ rc, c := newClient(t, nil)
+
+ cases := []struct {
+ name string
+ err error
+ }{
+ {"", ErrNameInvalid},
+ {"@", ErrNameInvalid},
+ {"@x", blob.ErrInvalidDigest},
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a new registry and push a new image.
+ err := rc.Push(t.Context(), c, tt.name, nil)
+ if !errors.Is(err, tt.err) {
+ t.Errorf("err = %v; want %v", err, tt.err)
+ }
+ })
+ }
+}
+
+func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
+ t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
+ return WithTrace(ctx, t), t
+}
+
+func TestPushZero(t *testing.T) {
+ rc, c := newClient(t, okHandler)
+ err := rc.Push(t.Context(), c, "empty", nil)
+ if !errors.Is(err, ErrManifestInvalid) {
+ t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
+ }
+}
+
+func TestPushSingle(t *testing.T) {
+ rc, c := newClient(t, okHandler)
+ err := rc.Push(t.Context(), c, "single", nil)
+ testutil.Check(t, err)
+}
+
+func TestPushMultiple(t *testing.T) {
+ rc, c := newClient(t, okHandler)
+ err := rc.Push(t.Context(), c, "multiple", nil)
+ testutil.Check(t, err)
+}
+
+func TestPushNotFound(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ t.Errorf("unexpected request: %v", r)
+ })
+ err := rc.Push(t.Context(), c, "notfound", nil)
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
+ }
+}
+
+func TestPushNullLayer(t *testing.T) {
+ rc, c := newClient(t, nil)
+ err := rc.Push(t.Context(), c, "null", nil)
+ if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
+ t.Errorf("err = %v; want invalid manifest", err)
+ }
+}
+
+func TestPushSizeMismatch(t *testing.T) {
+ rc, c := newClient(t, nil)
+ ctx, _ := withTraceUnexpected(t.Context())
+ got := rc.Push(ctx, c, "sizemismatch", nil)
+ if got == nil || !strings.Contains(got.Error(), "size mismatch") {
+ t.Errorf("err = %v; want size mismatch", got)
+ }
+}
+
+func TestPushInvalid(t *testing.T) {
+ rc, c := newClient(t, nil)
+ err := rc.Push(t.Context(), c, "invalid", nil)
+ if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
+ t.Errorf("err = %v; want invalid manifest", err)
+ }
+}
+
+func TestPushExistsAtRemote(t *testing.T) {
+ var pushed bool
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/uploads/") {
+ if !pushed {
+ // First push. Return an uploadURL.
+ pushed = true
+ w.Header().Set("Location", "http://blob.store/blobs/123")
+ return
+ }
+ w.WriteHeader(http.StatusAccepted)
+ return
+ }
+
+ io.Copy(io.Discard, r.Body)
+ w.WriteHeader(http.StatusOK)
+ })
+
+ rc.MaxStreams = 1 // prevent concurrent uploads
+
+ var errs []error
+ ctx := WithTrace(t.Context(), &Trace{
+ Update: func(_ *Layer, n int64, err error) {
+ // uploading one at a time so no need to lock
+ errs = append(errs, err)
+ },
+ })
+
+ check := testutil.Checker(t)
+
+ err := rc.Push(ctx, c, "single", nil)
+ check(err)
+
+ if !errors.Is(errors.Join(errs...), nil) {
+ t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
+ }
+
+ err = rc.Push(ctx, c, "single", nil)
+ check(err)
+}
+
+func TestPushRemoteError(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/blobs/") {
+ w.WriteHeader(500)
+ io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
+ return
+ }
+ })
+ got := rc.Push(t.Context(), c, "single", nil)
+ checkErrCode(t, got, 500, "blob_error")
+}
+
+func TestPushLocationError(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Location", ":///x")
+ w.WriteHeader(http.StatusAccepted)
+ })
+ got := rc.Push(t.Context(), c, "single", nil)
+ wantContains := "invalid upload URL"
+ if got == nil || !strings.Contains(got.Error(), wantContains) {
+ t.Errorf("err = %v; want to contain %v", got, wantContains)
+ }
+}
+
+func TestPushUploadRoundtripError(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if r.Host == "blob.store" {
+ w.WriteHeader(499) // force RoundTrip error on upload
+ return
+ }
+ w.Header().Set("Location", "http://blob.store/blobs/123")
+ })
+ got := rc.Push(t.Context(), c, "single", nil)
+ if !errors.Is(got, errRoundTrip) {
+ t.Errorf("got = %v; want %v", got, errRoundTrip)
+ }
+}
+
+func TestPushUploadFileOpenError(t *testing.T) {
+ rc, c := newClient(t, okHandler)
+ ctx := WithTrace(t.Context(), &Trace{
+ Update: func(l *Layer, _ int64, err error) {
+ // Remove the file just before it is opened for upload,
+ // but after the initial Stat that happens before the
+ // upload starts
+ os.Remove(c.GetFile(l.Digest))
+ },
+ })
+ got := rc.Push(ctx, c, "single", nil)
+ if !errors.Is(got, fs.ErrNotExist) {
+ t.Errorf("got = %v; want fs.ErrNotExist", got)
+ }
+}
+
+func TestPushCommitRoundtripError(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/blobs/") {
+ panic("unexpected")
+ }
+ w.WriteHeader(499) // force RoundTrip error
+ })
+ err := rc.Push(t.Context(), c, "zero", nil)
+ if !errors.Is(err, errRoundTrip) {
+ t.Errorf("err = %v; want %v", err, errRoundTrip)
+ }
+}
+
+func checkNotExist(t *testing.T, err error) {
+ t.Helper()
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Fatalf("err = %v; want fs.ErrNotExist", err)
+ }
+}
+
+func TestRegistryPullInvalidName(t *testing.T) {
+ rc, c := newClient(t, nil)
+ err := rc.Pull(t.Context(), c, "://")
+ if !errors.Is(err, ErrNameInvalid) {
+ t.Errorf("err = %v; want %v", err, ErrNameInvalid)
+ }
+}
+
+func TestRegistryPullInvalidManifest(t *testing.T) {
+ cases := []string{
+ "",
+ "null",
+ "!!!",
+ `{"layers":[]}`,
+ }
+
+ for _, resp := range cases {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, resp)
+ })
+ err := rc.Pull(t.Context(), c, "x")
+ if !errors.Is(err, ErrManifestInvalid) {
+ t.Errorf("err = %v; want invalid manifest", err)
+ }
+ }
+}
+
+func TestRegistryPullNotCached(t *testing.T) {
+ check := testutil.Checker(t)
+
+ var c *blob.DiskCache
+ var rc *Registry
+
+ d := blob.DigestFromBytes("some data")
+ rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/blobs/") {
+ io.WriteString(w, "some data")
+ return
+ }
+ fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
+ })
+
+ // Confirm that the layer does not exist locally
+ _, err := ResolveLocal(c, "model")
+ checkNotExist(t, err)
+
+ _, err = c.Get(d)
+ checkNotExist(t, err)
+
+ err = rc.Pull(t.Context(), c, "model")
+ check(err)
+
+ mw, err := rc.Resolve(t.Context(), "model")
+ check(err)
+ mg, err := ResolveLocal(c, "model")
+ check(err)
+ if !reflect.DeepEqual(mw, mg) {
+ t.Errorf("mw = %v; mg = %v", mw, mg)
+ }
+
+ // Confirm successful download
+ info, err := c.Get(d)
+ check(err)
+ if info.Digest != d {
+ t.Errorf("info.Digest = %v; want %v", info.Digest, d)
+ }
+ if info.Size != 9 {
+ t.Errorf("info.Size = %v; want %v", info.Size, 9)
+ }
+
+ data, err := os.ReadFile(c.GetFile(d))
+ check(err)
+ if string(data) != "some data" {
+ t.Errorf("data = %q; want %q", data, "exists")
+ }
+}
+
+func TestRegistryPullCached(t *testing.T) {
+ cached := blob.DigestFromBytes("exists")
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/blobs/") {
+ w.WriteHeader(499) // should not be called
+ return
+ }
+ if strings.Contains(r.URL.Path, "/manifests/") {
+ fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
+ }
+ })
+
+ var errs []error
+ var reads []int64
+ ctx := WithTrace(t.Context(), &Trace{
+ Update: func(d *Layer, n int64, err error) {
+ t.Logf("update %v %d %v", d, n, err)
+ reads = append(reads, n)
+ errs = append(errs, err)
+ },
+ })
+
+ ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ defer cancel()
+
+ err := rc.Pull(ctx, c, "single")
+ testutil.Check(t, err)
+
+ want := []int64{6}
+ if !errors.Is(errors.Join(errs...), ErrCached) {
+ t.Errorf("errs = %v; want %v", errs, ErrCached)
+ }
+ if !slices.Equal(reads, want) {
+ t.Errorf("pairs = %v; want %v", reads, want)
+ }
+}
+
+func TestRegistryPullManifestNotFound(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ })
+ err := rc.Pull(t.Context(), c, "notfound")
+ checkErrCode(t, err, 404, "")
+}
+
+func TestRegistryPullResolveRemoteError(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
+ })
+ err := rc.Pull(t.Context(), c, "single")
+ checkErrCode(t, err, 500, "an_error")
+}
+
+func TestRegistryPullResolveRoundtripError(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/manifests/") {
+ w.WriteHeader(499) // force RoundTrip error
+ return
+ }
+ })
+ err := rc.Pull(t.Context(), c, "single")
+ if !errors.Is(err, errRoundTrip) {
+ t.Errorf("err = %v; want %v", err, errRoundTrip)
+ }
+}
+
+// TestRegistryPullMixedCachedNotCached tests that cached layers do not
+// interfere with pulling layers that are not cached
+func TestRegistryPullMixedCachedNotCached(t *testing.T) {
+ x := blob.DigestFromBytes("xxxxxx")
+ e := blob.DigestFromBytes("exists")
+ y := blob.DigestFromBytes("yyyyyy")
+
+ for i := range 10 {
+ t.Logf("iteration %d", i)
+
+ digests := []blob.Digest{x, e, y}
+
+ rand.Shuffle(len(digests), func(i, j int) {
+ digests[i], digests[j] = digests[j], digests[i]
+ })
+
+ manifest := fmt.Sprintf(`{
+ "layers": [
+ {"digest":"%s","size":6},
+ {"digest":"%s","size":6},
+ {"digest":"%s","size":6}
+ ]
+ }`, digests[0], digests[1], digests[2])
+
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ switch path.Base(r.URL.Path) {
+ case "latest":
+ io.WriteString(w, manifest)
+ case x.String():
+ io.WriteString(w, "xxxxxx")
+ case e.String():
+ io.WriteString(w, "exists")
+ case y.String():
+ io.WriteString(w, "yyyyyy")
+ default:
+ panic(fmt.Sprintf("unexpected request: %v", r))
+ }
+ })
+
+ ctx := WithTrace(t.Context(), &Trace{
+ Update: func(l *Layer, n int64, err error) {
+ t.Logf("update %v %d %v", l, n, err)
+ },
+ })
+
+ // Check that we pull all layers that we can.
+
+ err := rc.Pull(ctx, c, "mixed")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, d := range digests {
+ info, err := c.Get(d)
+ if err != nil {
+ t.Fatalf("Get(%v): %v", d, err)
+ }
+ if info.Size != 6 {
+ t.Errorf("info.Size = %v; want %v", info.Size, 6)
+ }
+ }
+ }
+}
+
+func TestRegistryPullChunking(t *testing.T) {
+ rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
+ if r.URL.Host != "blob.store" {
+ // The production registry redirects to the blob store.
+ http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
+ return
+ }
+ if strings.Contains(r.URL.Path, "/blobs/") {
+ rng := r.Header.Get("Range")
+ if rng == "" {
+ http.Error(w, "missing range", http.StatusBadRequest)
+ return
+ }
+ _, c, err := chunks.ParseRange(r.Header.Get("Range"))
+ if err != nil {
+ panic(err)
+ }
+ io.WriteString(w, "remote"[c.Start:c.End+1])
+ return
+ }
+ fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
+ })
+
+ // Force chunking by setting the threshold to less than the size of the
+ // layer.
+ rc.ChunkingThreshold = 3
+ rc.MaxChunkSize = 3
+
+ var reads []int64
+ ctx := WithTrace(t.Context(), &Trace{
+ Update: func(d *Layer, n int64, err error) {
+ if err != nil {
+ t.Errorf("update %v %d %v", d, n, err)
+ }
+ reads = append(reads, n)
+ },
+ })
+
+ err := rc.Pull(ctx, c, "remote")
+ testutil.Check(t, err)
+
+ want := []int64{0, 3, 6}
+ if !slices.Equal(reads, want) {
+ t.Errorf("reads = %v; want %v", reads, want)
+ }
+}
+
+func TestRegistryResolveByDigest(t *testing.T) {
+ check := testutil.Checker(t)
+
+ exists := blob.DigestFromBytes("exists")
+ rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
+ w.WriteHeader(499) // should not hit manifest endpoint
+ }
+ fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
+ })
+
+ _, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
+ check(err)
+}
+
+func TestInsecureSkipVerify(t *testing.T) {
+ exists := blob.DigestFromBytes("exists")
+
+ s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
+ }))
+ defer s.Close()
+
+ const name = "ollama.com/library/insecure"
+
+ var rc Registry
+ url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
+ _, err := rc.Resolve(t.Context(), url)
+ if err == nil || !strings.Contains(err.Error(), "failed to verify") {
+ t.Errorf("err = %v; want cert verifiction failure", err)
+ }
+
+ url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
+ _, err = rc.Resolve(t.Context(), url)
+ testutil.Check(t, err)
+}
+
+func TestCanRetry(t *testing.T) {
+ cases := []struct {
+ err error
+ want bool
+ }{
+ {nil, false},
+ {errors.New("x"), false},
+ {ErrCached, false},
+ {ErrManifestInvalid, false},
+ {ErrNameInvalid, false},
+ {&Error{Status: 100}, false},
+ {&Error{Status: 500}, true},
+ }
+ for _, tt := range cases {
+ if got := canRetry(tt.err); got != tt.want {
+ t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
+ }
+ }
+}
diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go
new file mode 100644
index 000000000..8e53040ad
--- /dev/null
+++ b/server/internal/client/ollama/trace.go
@@ -0,0 +1,48 @@
+package ollama
+
+import (
+ "context"
+)
+
+// Trace is a set of functions that are called to report progress during blob
+// downloads and uploads.
+type Trace struct {
+ // Update is called during [Registry.Push] and [Registry.Pull] to
+ // report the progress of blob uploads and downloads.
+ //
+ // It is called once at the beginning of the download with a zero n and
+ // then once per read operation with the number of bytes read so far,
+ // and an error if any.
+ //
+ // A function assigned must be safe for concurrent use. The function is
+ // called synchronously and so should not block or take long to run.
+ Update func(_ *Layer, n int64, _ error)
+}
+
+func (t *Trace) update(l *Layer, n int64, err error) {
+ if t.Update != nil {
+ t.Update(l, n, err)
+ }
+}
+
+type traceKey struct{}
+
+// WithTrace returns a context derived from ctx that uses t to report trace
+// events.
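+//
+// A minimal usage sketch (the registry, cache, and model name below are
+// hypothetical placeholders):
+//
+//	ctx := WithTrace(ctx, &Trace{
+//		Update: func(l *Layer, n int64, err error) {
+//			log.Printf("%s: %d/%d bytes (err=%v)", l.Digest, n, l.Size, err)
+//		},
+//	})
+//	err := reg.Pull(ctx, cache, "library/smol:latest")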
+func WithTrace(ctx context.Context, t *Trace) context.Context {
+ return context.WithValue(ctx, traceKey{}, t)
+}
+
+var emptyTrace = &Trace{}
+
+// traceFromContext returns the Trace associated with ctx, or an empty Trace if
+// none is found.
+//
+// It never returns nil.
+func traceFromContext(ctx context.Context) *Trace {
+ t, _ := ctx.Value(traceKey{}).(*Trace)
+ if t == nil {
+ return emptyTrace
+ }
+ return t
+}
diff --git a/server/internal/cmd/opp/internal/safetensors/safetensors.go b/server/internal/cmd/opp/internal/safetensors/safetensors.go
new file mode 100644
index 000000000..7f3e99798
--- /dev/null
+++ b/server/internal/cmd/opp/internal/safetensors/safetensors.go
@@ -0,0 +1,220 @@
+// Package safetensors provides a reader for safetensors directories and files.
+package safetensors
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/fs"
+ "iter"
+ "slices"
+ "strconv"
+ "strings"
+)
+
+// Tensor represents a single tensor in a safetensors file.
+//
+// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
+//
+// It is not safe for use across multiple goroutines.
+type Tensor struct {
+ name string
+ dataType string
+ shape []int64
+
+ fsys fs.FS
+ fname string // entry name in fsys
+ offset int64
+ size int64
+}
+
+type Model struct {
+ fsys fs.FS
+}
+
+func Read(fsys fs.FS) (*Model, error) {
+ return &Model{fsys: fsys}, nil
+}
+
+func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
+ return func(yield func(*Tensor, error) bool) {
+ entries, err := fs.Glob(m.fsys, "*.safetensors")
+ if err != nil {
+ yield(nil, err)
+ return
+ }
+ for _, e := range entries {
+ tt, err := m.readTensors(e)
+ if err != nil {
+ yield(nil, err)
+ return
+ }
+ for _, t := range tt {
+ if !yield(t, nil) {
+ return
+ }
+ }
+ }
+ }
+}
+
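+// readTensors parses a single .safetensors file. The layout assumed here is
+// the standard safetensors format: an 8-byte little-endian header length,
+// followed by a JSON header mapping tensor names to their dtype, shape, and
+// data offsets, followed by the raw tensor bytes addressed by those offsets.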
+func (m *Model) readTensors(fname string) ([]*Tensor, error) {
+ f, err := m.fsys.Open(fname)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ finfo, err := f.Stat()
+ if err != nil {
+ return nil, err
+ }
+
+ headerSize, err := readInt64(f)
+ if err != nil {
+ return nil, err
+ }
+
+ data := make([]byte, headerSize)
+ _, err = io.ReadFull(f, data)
+ if err != nil {
+ return nil, err
+ }
+
+ var raws map[string]json.RawMessage
+ if err := json.Unmarshal(data, &raws); err != nil {
+ return nil, err
+ }
+
+ // TODO(bmizerany): do something with metadata? This could be another
+ // header read if needed. We also need to figure out if the metadata is
+ // present in only one .safetensors file or if each file may have its
+ // own and if it needs to follow each tensor. Currently, I (bmizerany)
+ // am only seeing it show up with one entry for the file type, which is
+ // always "pt".
+
+ tt := make([]*Tensor, 0, len(raws))
+ for name, raw := range raws {
+ if !strings.HasPrefix(name, "model.layer") {
+ continue
+ }
+ var v struct {
+ DataType string `json:"dtype"`
+ Shape []int64 `json:"shape"`
+ Offsets []int64 `json:"data_offsets"`
+ }
+ if err := json.Unmarshal(raw, &v); err != nil {
+ return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
+ }
+ if len(v.Offsets) != 2 {
+ return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
+ }
+
+ // TODO(bmizerany): after collecting, validate all offsets make
+ // tensors contiguous?
+ begin, end := v.Offsets[0], v.Offsets[1]
+ if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
+ return nil, err
+ }
+
+ // TODO(bmizerany): just yield.. don't be silly and make a slice :)
+ tt = append(tt, &Tensor{
+ name: name,
+ dataType: v.DataType,
+ shape: v.Shape,
+ fsys: m.fsys,
+ fname: fname,
+ offset: begin,
+ size: end - begin,
+ })
+ }
+ return tt, nil
+}
+
+func checkBeginEnd(size, begin, end int64) error {
+ if begin < 0 {
+ return fmt.Errorf("begin must not be negative: %d", begin)
+ }
+ if end < 0 {
+ return fmt.Errorf("end must not be negative: %d", end)
+ }
+ if end < begin {
+ return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
+ }
+ if end > size {
+ return fmt.Errorf("end must be <= size: %d > %d", end, size)
+ }
+ return nil
+}
+
+func readInt64(r io.Reader) (int64, error) {
+ var v uint64
+ var buf [8]byte
+ if _, err := io.ReadFull(r, buf[:]); err != nil {
+ return 0, err
+ }
+ for i := range buf {
+ v |= uint64(buf[i]) << (8 * i)
+ }
+ return int64(v), nil
+}
+
+type Shape []int64
+
+func (s Shape) String() string {
+ var b strings.Builder
+ b.WriteByte('[')
+ for i, v := range s {
+ if i > 0 {
+ b.WriteByte(',')
+ }
+ b.WriteString(strconv.FormatInt(v, 10))
+ }
+ b.WriteByte(']')
+ return b.String()
+}
+
+func (t *Tensor) Name() string { return t.name }
+func (t *Tensor) DataType() string { return t.dataType }
+func (t *Tensor) Size() int64 { return t.size }
+func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
+
+func (t *Tensor) Reader() (io.ReadCloser, error) {
+ f, err := t.fsys.Open(t.fname)
+ if err != nil {
+ return nil, err
+ }
+ r := newSectionReader(f, t.offset, t.size)
+ rc := struct {
+ io.Reader
+ io.Closer
+ }{r, f}
+ return rc, nil
+}
+
+// newSectionReader returns a new io.Reader that reads n bytes from r starting
+// at offset. It is a convenience function for creating an io.SectionReader
+// when r may not be an io.ReaderAt.
+//
+// If r is an io.ReaderAt, it is wrapped in an io.SectionReader directly.
+// Otherwise, if r is an io.ReadSeeker, it is seeked to offset and limited to
+// n bytes. Failing both, the slow path discards the first offset bytes from r
+// and returns a reader limited to n bytes.
+func newSectionReader(r io.Reader, offset, n int64) io.Reader {
+ if r, ok := r.(io.ReaderAt); ok {
+ return io.NewSectionReader(r, offset, n)
+ }
+ if r, ok := r.(io.ReadSeeker); ok {
+ r.Seek(offset, io.SeekStart)
+ return io.LimitReader(r, n)
+ }
+ // Discard to offset and return a limited reader.
+ _, err := io.CopyN(io.Discard, r, offset)
+ if err != nil {
+ return nil
+ }
+ return io.LimitReader(r, n)
+}
diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go
new file mode 100644
index 000000000..12199cf3e
--- /dev/null
+++ b/server/internal/cmd/opp/opp.go
@@ -0,0 +1,366 @@
+package main
+
+import (
+ "bytes"
+ "cmp"
+ "context"
+ "encoding/json"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "mime"
+ "net/http"
+ "os"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ollama/ollama/server/internal/cache/blob"
+ "github.com/ollama/ollama/server/internal/client/ollama"
+ "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
+ "golang.org/x/sync/errgroup"
+)
+
+var stdout io.Writer = os.Stdout
+
+const usage = `Opp is a tool for pushing and pulling Ollama models.
+
+Usage:
+
+ opp [flags] <push|pull|import>
+
+Commands:
+
+ push Upload a model to the Ollama server.
+ pull Download a model from the Ollama server.
+ import Import a model from a local safetensor directory.
+
+Examples:
+
+ # Pull a model from the Ollama server.
+ opp pull library/llama3.2:latest
+
+ # Push a model to the Ollama server.
+ opp push username/my_model:8b
+
+ # Import a model from a local safetensor directory.
+ opp import /path/to/safetensor
+
+Environment Variables:
+
+ OLLAMA_MODELS
+ The directory where models are pushed and pulled from
+ (default ~/.ollama/models).
+`
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprint(os.Stderr, usage)
+ }
+ flag.Parse()
+
+ c, err := ollama.DefaultCache()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ rc, err := ollama.RegistryFromEnv()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ ctx := context.Background()
+
+ err = func() error {
+ switch cmd := flag.Arg(0); cmd {
+ case "pull":
+ return cmdPull(ctx, rc, c)
+ case "push":
+ return cmdPush(ctx, rc, c)
+ case "import":
+ return cmdImport(ctx, c)
+ default:
+ if cmd == "" {
+ flag.Usage()
+ } else {
+ fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
+ }
+ os.Exit(2)
+ return errors.New("unreachable")
+ }
+ }()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "opp: %v\n", err)
+ os.Exit(1)
+ }
+}
+
+func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+ model := flag.Arg(1)
+ if model == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ tr := http.DefaultTransport.(*http.Transport).Clone()
+ // TODO(bmizerany): configure transport?
+ rc.HTTPClient = &http.Client{Transport: tr}
+
+ var mu sync.Mutex
+ p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
+
+ var pb bytes.Buffer
+ printProgress := func() {
+ pb.Reset()
+ mu.Lock()
+ for d, s := range p {
+ // Write progress to a buffer first to avoid blocking
+ // on stdout while holding the lock.
+ stamp := time.Now().Format("2006/01/02 15:04:05")
+ fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
+ if s[0] == s[1] {
+ delete(p, d)
+ }
+ }
+ mu.Unlock()
+ io.Copy(stdout, &pb)
+ }
+
+ ctx = ollama.WithTrace(ctx, &ollama.Trace{
+ Update: func(l *ollama.Layer, n int64, err error) {
+ if err != nil && !errors.Is(err, ollama.ErrCached) {
+ fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
+ return
+ }
+
+ mu.Lock()
+ p[l.Digest] = [2]int64{l.Size, n}
+ mu.Unlock()
+ },
+ })
+
+ errc := make(chan error)
+ go func() {
+ errc <- rc.Pull(ctx, c, model)
+ }()
+
+ t := time.NewTicker(time.Second)
+ defer t.Stop()
+ for {
+ select {
+ case <-t.C:
+ printProgress()
+ case err := <-errc:
+ printProgress()
+ return err
+ }
+ }
+}
+
+func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+ args := flag.Args()[1:]
+ flag := flag.NewFlagSet("push", flag.ExitOnError)
+ flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: opp push \n")
+ flag.PrintDefaults()
+ }
+ flag.Parse(args)
+
+ model := flag.Arg(0)
+ if model == "" {
+ return fmt.Errorf("missing model argument")
+ }
+
+ from := cmp.Or(*flagFrom, model)
+ m, err := ollama.ResolveLocal(c, from)
+ if err != nil {
+ return err
+ }
+
+ ctx = ollama.WithTrace(ctx, &ollama.Trace{
+ Update: func(l *ollama.Layer, n int64, err error) {
+ switch {
+ case errors.Is(err, ollama.ErrCached):
+ fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
+ case err != nil:
+ fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
+ case n == 0:
+ l := m.Layer(l.Digest)
+ mt, p, _ := mime.ParseMediaType(l.MediaType)
+ mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
+ switch mt {
+ case "tensor":
+ fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
+ default:
+ fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
+ }
+ }
+ },
+ })
+
+ return rc.Push(ctx, c, model, &ollama.PushParams{
+ From: from,
+ })
+}
+
+type trackingReader struct {
+ io.Reader
+ n *atomic.Int64
+}
+
+func (r *trackingReader) Read(p []byte) (n int, err error) {
+ n, err = r.Reader.Read(p)
+ r.n.Add(int64(n))
+ return n, err
+}
+
+func cmdImport(ctx context.Context, c *blob.DiskCache) error {
+ args := flag.Args()[1:]
+ flag := flag.NewFlagSet("import", flag.ExitOnError)
+ flagAs := flag.String("as", "", "Import using the provided name.")
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: opp import \n")
+ flag.PrintDefaults()
+ }
+ flag.Parse(args)
+
+ dir := cmp.Or(flag.Arg(0), ".")
+ fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
+
+ m, err := safetensors.Read(os.DirFS(dir))
+ if err != nil {
+ return err
+ }
+
+ var total int64
+ var tt []*safetensors.Tensor
+ for t, err := range m.Tensors() {
+ if err != nil {
+ return err
+ }
+ tt = append(tt, t)
+ total += t.Size()
+ }
+
+ var n atomic.Int64
+ done := make(chan error)
+ go func() {
+ layers := make([]*ollama.Layer, len(tt))
+ var g errgroup.Group
+ g.SetLimit(runtime.GOMAXPROCS(0))
+ var ctxErr error
+ for i, t := range tt {
+ if ctx.Err() != nil {
+ // The context may cancel AFTER we exit the
+ // loop, and so if we use ctx.Err() after the
+ // loop we may report it as the error that
+ // broke the loop, when it was not. This can
+ // manifest as a false-negative, leading the
+ // user to think their import failed when it
+ // did not, so capture it if and only if we
+ // exit the loop because of a ctx.Err() and
+ // report it.
+ ctxErr = ctx.Err()
+ break
+ }
+ g.Go(func() (err error) {
+ rc, err := t.Reader()
+ if err != nil {
+ return err
+ }
+ defer rc.Close()
+ tr := &trackingReader{rc, &n}
+ d, err := c.Import(tr, t.Size())
+ if err != nil {
+ return err
+ }
+ if err := rc.Close(); err != nil {
+ return err
+ }
+
+ layers[i] = &ollama.Layer{
+ Digest: d,
+ Size: t.Size(),
+ MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
+ "name": t.Name(),
+ "dtype": t.DataType(),
+ "shape": t.Shape().String(),
+ }),
+ }
+
+ return nil
+ })
+ }
+
+ done <- func() error {
+ if err := errors.Join(g.Wait(), ctxErr); err != nil {
+ return err
+ }
+ m := &ollama.Manifest{Layers: layers}
+ data, err := json.MarshalIndent(m, "", " ")
+ if err != nil {
+ return err
+ }
+ d := blob.DigestFromBytes(data)
+ err = blob.PutBytes(c, d, data)
+ if err != nil {
+ return err
+ }
+ return c.Link(*flagAs, d)
+ }()
+ }()
+
+ fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
+
+ csiHideCursor(stdout)
+ defer csiShowCursor(stdout)
+
+ csiSavePos(stdout)
+ writeProgress := func() {
+ csiRestorePos(stdout)
+ nn := n.Load()
+ fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
+ formatNatural(nn),
+ formatNatural(total),
+ nn*100/total,
+ ansiClearToEndOfLine,
+ )
+ }
+
+ ticker := time.NewTicker(time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ writeProgress()
+ case err := <-done:
+ writeProgress()
+ return err
+ }
+ }
+}
+
+func formatNatural(n int64) string {
+ switch {
+ case n < 1024:
+ return fmt.Sprintf("%d B", n)
+ case n < 1024*1024:
+ return fmt.Sprintf("%.1f KB", float64(n)/1024)
+ case n < 1024*1024*1024:
+ return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
+ default:
+ return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
+ }
+}
+
+const ansiClearToEndOfLine = "\033[K"
+
+func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
+func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
+func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
+func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }
diff --git a/server/internal/cmd/oppbench/oppbench.go b/server/internal/cmd/oppbench/oppbench.go
new file mode 100644
index 000000000..7a5305947
--- /dev/null
+++ b/server/internal/cmd/oppbench/oppbench.go
@@ -0,0 +1,11 @@
+package main
+
+import (
+ "fmt"
+ "os"
+)
+
+func main() {
+ fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
+ os.Exit(1)
+}
diff --git a/server/internal/cmd/oppbench/oppbench_test.go b/server/internal/cmd/oppbench/oppbench_test.go
new file mode 100644
index 000000000..c71d6cded
--- /dev/null
+++ b/server/internal/cmd/oppbench/oppbench_test.go
@@ -0,0 +1,107 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "runtime"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ollama/ollama/server/internal/chunks"
+ "golang.org/x/sync/errgroup"
+)
+
+func BenchmarkDownload(b *testing.B) {
+ run := func(fileSize, chunkSize int64) {
+ name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
+ b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
+ }
+
+ run(100<<20, 8<<20)
+ run(100<<20, 16<<20)
+ run(100<<20, 32<<20)
+ run(100<<20, 64<<20)
+ run(100<<20, 128<<20) // 1 chunk
+}
+
+func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
+ const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
+ req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
+ if err != nil {
+ return err
+ }
+ req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
+ res, err := c.Do(req)
+ if err != nil {
+ return err
+ }
+ defer res.Body.Close()
+
+ _, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
+ return err
+}
+
+var sleepTime atomic.Int64
+
+func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
+ client := &http.Client{
+ Transport: func() http.RoundTripper {
+ tr := http.DefaultTransport.(*http.Transport).Clone()
+ tr.DisableKeepAlives = true
+ return tr
+ }(),
+ }
+ defer client.CloseIdleConnections()
+
+ // warm up the client
+ run(context.Background(), client, chunks.New(0, 1<<20))
+
+ b.SetBytes(fileSize)
+ b.ReportAllocs()
+
+ // Give our CDN a moment to breathe between benchmarks.
+ time.Sleep(time.Duration(sleepTime.Swap(int64(3 * time.Second))))
+
+ for b.Loop() {
+ g, ctx := errgroup.WithContext(b.Context())
+ g.SetLimit(runtime.GOMAXPROCS(0))
+ for chunk := range chunks.Of(fileSize, chunkSize) {
+ g.Go(func() error { return run(ctx, client, chunk) })
+ }
+ if err := g.Wait(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkWrite(b *testing.B) {
+ b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
+}
+
+func benchmarkWrite(b *testing.B, chunkSize int) {
+ b.ReportAllocs()
+
+ dir := b.TempDir()
+ f, err := os.Create(filepath.Join(dir, "write-single"))
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer f.Close()
+
+ data := make([]byte, chunkSize)
+ b.SetBytes(int64(chunkSize))
+ r := bytes.NewReader(data)
+ for b.Loop() {
+ r.Reset(data)
+ _, err := io.Copy(f, r)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go
new file mode 100644
index 000000000..1f0634f7c
--- /dev/null
+++ b/server/internal/internal/backoff/backoff.go
@@ -0,0 +1,48 @@
+package backoff
+
+import (
+ "context"
+ "iter"
+ "math/rand/v2"
+ "time"
+)
+
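+// Loop returns an iterator that yields an increasing attempt count, starting
+// at zero, pausing between iterations with a jittered, quadratically growing
+// delay capped at maxBackoff. Once ctx is done it yields ctx.Err() and stops;
+// otherwise it runs until the caller breaks out of the loop.
+//
+// A sketch of typical use (doWork is a hypothetical, fallible operation):
+//
+//	for _, err := range backoff.Loop(ctx, 3*time.Second) {
+//		if err != nil {
+//			return err // context canceled or deadline exceeded
+//		}
+//		if err := doWork(); err == nil {
+//			return nil
+//		}
+//	}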
+func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
+ var n int
+ return func(yield func(int, error) bool) {
+ var t *time.Timer
+ for {
+ if ctx.Err() != nil {
+ yield(n, ctx.Err())
+ return
+ }
+
+ if !yield(n, nil) {
+ return
+ }
+
+ n++
+
+ // n^2 backoff timer is a little smoother than the
+ // common choice of 2^n.
+ d := time.Duration(n*n) * 10 * time.Millisecond
+ if d > maxBackoff {
+ d = maxBackoff
+ }
+ // Randomize the delay between 0.5x and 1.5x of d, in order
+ // to prevent accidental "thundering herd" problems.
+ d = time.Duration(float64(d) * (rand.Float64() + 0.5))
+
+ if t == nil {
+ t = time.NewTimer(d)
+ } else {
+ t.Reset(d)
+ }
+ select {
+ case <-ctx.Done():
+ t.Stop()
+ case <-t.C:
+ }
+ }
+ }
+}
diff --git a/server/internal/internal/backoff/backoff_synctest_test.go b/server/internal/internal/backoff/backoff_synctest_test.go
new file mode 100644
index 000000000..cf17ce80a
--- /dev/null
+++ b/server/internal/internal/backoff/backoff_synctest_test.go
@@ -0,0 +1,40 @@
+//go:build goexperiment.synctest
+
+package backoff
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "testing/synctest"
+ "time"
+)
+
+func TestLoop(t *testing.T) {
+ synctest.Run(func() {
+ last := -1
+
+ ctx, cancel := context.WithCancel(t.Context())
+ defer cancel()
+
+ for n, err := range Loop(ctx, 100*time.Millisecond) {
+ if !errors.Is(err, ctx.Err()) {
+ t.Errorf("err = %v, want nil", err)
+ }
+ if err != nil {
+ break
+ }
+ if n != last+1 {
+ t.Errorf("n = %d, want %d", n, last+1)
+ }
+ last = n
+ if n > 5 {
+ cancel()
+ }
+ }
+
+ if last != 6 {
+ t.Errorf("last = %d, want 6", last)
+ }
+ })
+}
diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go
new file mode 100644
index 000000000..bb8438a78
--- /dev/null
+++ b/server/internal/internal/backoff/backoff_test.go
@@ -0,0 +1,38 @@
+package backoff
+
+import (
+ "context"
+ "testing"
+ "testing/synctest"
+ "time"
+)
+
+func TestLoopAllocs(t *testing.T) {
+ for i := range 3 {
+ got := testing.AllocsPerRun(1000, func() {
+ for tick := range Loop(t.Context(), 1) {
+ if tick >= i {
+ break
+ }
+ }
+ })
+ want := float64(0)
+ if i > 0 {
+ want = 3 // due to time.NewTimer
+ }
+ if got > want {
+ t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
+ }
+ }
+}
+
+func BenchmarkLoop(b *testing.B) {
+ ctx := context.Background()
+ synctest.Run(func() {
+ for n := range Loop(ctx, 100*time.Millisecond) {
+ if n == b.N {
+ break
+ }
+ }
+ })
+}
diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go
new file mode 100644
index 000000000..361cce76f
--- /dev/null
+++ b/server/internal/internal/names/name.go
@@ -0,0 +1,229 @@
+package names
+
+import (
+ "cmp"
+ "fmt"
+ "strings"
+
+ "github.com/ollama/ollama/server/internal/internal/stringsx"
+)
+
+const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
+
+type Name struct {
+ // Make incomparable to enforce use of Compare / Equal for
+ // case-insensitive comparisons.
+ _ [0]func()
+
+ h string
+ n string
+ m string
+ t string
+}
+
+// Parse parses and assembles a Name from a name string. The
+// format of a valid name string is:
+//
+// s:
+// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
+// { host } "/" { namespace } "/" { model } ":" { tag }
+// { host } "/" { namespace } "/" { model } "@" { digest }
+// { host } "/" { namespace } "/" { model }
+// { namespace } "/" { model } ":" { tag } "@" { digest }
+// { namespace } "/" { model } ":" { tag }
+// { namespace } "/" { model } "@" { digest }
+// { namespace } "/" { model }
+// { model } ":" { tag } "@" { digest }
+// { model } ":" { tag }
+// { model } "@" { digest }
+// { model }
+// "@" { digest }
+// host:
+// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
+// length: [1, 350]
+// namespace:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
+// length: [1, 80]
+// model:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
+// length: [1, 80]
+// tag:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
+// length: [1, 80]
+// digest:
+// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
+// length: [1, 80]
+//
+// The name returned is not guaranteed to be valid. If it is not valid, the
+// field values are left in an undefined state. Use [Name.IsValid] to check
+// if the name is valid.
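+//
+// For illustration, Parse("ollama.com/library/smol:latest") yields a Name
+// whose Host is "ollama.com", Namespace is "library", Model is "smol", and
+// Tag is "latest".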
+func Parse(s string) Name {
+ if len(s) > MaxNameLength {
+ return Name{}
+ }
+
+ var n Name
+ var tail string
+ var c byte
+ for {
+ s, tail, c = cutLastAny(s, "/:")
+ switch c {
+ case ':':
+ n.t = tail
+ continue // look for model
+ case '/':
+ n.h, n.n, _ = cutLastAny(s, "/")
+ n.m = tail
+ return n
+ case 0:
+ n.m = tail
+ return n
+ }
+ }
+}
+
+// ParseExtended parses and returns any scheme, Name, and digest from s in
+// the form [scheme://][name][@digest]. All parts are optional.
+//
+// If the scheme is present, it must be followed by "://". The digest is
+// prefixed by "@" and comes after the name. The name is parsed using [Parse].
+//
+// The scheme and digest are stripped before the name is parsed by [Parse].
+//
+// For convenience, the scheme is never empty. If the scheme is not present, the
+// returned scheme is "https".
+//
+// Examples:
+//
+// http://ollama.com/bmizerany/smol:latest@digest
+// https://ollama.com/bmizerany/smol:latest
+// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
+func ParseExtended(s string) (scheme string, _ Name, digest string) {
+ i := strings.Index(s, "://")
+ if i >= 0 {
+ scheme = s[:i]
+ s = s[i+3:]
+ }
+ i = strings.LastIndex(s, "@")
+ if i >= 0 {
+ digest = s[i+1:]
+ s = s[:i]
+ }
+ return scheme, Parse(s), digest
+}
+
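+// FormatExtended is the inverse of [ParseExtended]. It formats the scheme,
+// name, and digest back into the form [scheme://][name][@digest], omitting
+// any parts that are empty.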
+func FormatExtended(scheme string, n Name, digest string) string {
+ var b strings.Builder
+ if scheme != "" {
+ b.WriteString(scheme)
+ b.WriteString("://")
+ }
+ b.WriteString(n.String())
+ if digest != "" {
+ b.WriteByte('@')
+ b.WriteString(digest)
+ }
+ return b.String()
+}
+
+// Merge merges two names into a single name. Non-empty host, namespace, and
+// tag parts of a take precedence over fields in b. The model field is left as
+// is.
+//
+// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check
+// if the name is valid.
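+//
+// For example, merging Parse("bmizerany/smol") with
+// Parse("ollama.com/library/_:latest") yields "ollama.com/bmizerany/smol:latest".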
+func Merge(a, b Name) Name {
+ a.h = cmp.Or(a.h, b.h)
+ a.n = cmp.Or(a.n, b.n)
+ a.t = cmp.Or(a.t, b.t)
+ return a
+}
+
+// IsValid returns true if the name is valid.
+func (n Name) IsValid() bool {
+ if n.h != "" && !isValidHost(n.h) {
+ return false
+ }
+ if n.n != "" && !isValidNamespace(n.n) {
+ return false
+ }
+ if n.m != "" && !isValidModel(n.m) {
+ return false
+ }
+ if n.t != "" && !isValidTag(n.t) {
+ return false
+ }
+ return true
+}
+
+func (n Name) IsFullyQualified() bool {
+ return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
+}
+
+func isValidHost(_ string) bool {
+ return true // TODO: implement
+}
+
+func isValidNamespace(_ string) bool {
+ return true // TODO: implement
+}
+
+func isValidModel(_ string) bool {
+ return true // TODO: implement
+}
+
+func isValidTag(_ string) bool {
+ return true // TODO: implement
+}
+
+func (n Name) Host() string { return n.h }
+func (n Name) Namespace() string { return n.n }
+func (n Name) Model() string { return n.m }
+func (n Name) Tag() string { return n.t }
+
+// Compare compares n and o case-insensitively. It returns 0 if n and o are
+// equal, -1 if n sorts before o, and 1 if n sorts after o.
+func (n Name) Compare(o Name) int {
+ return cmp.Or(
+ stringsx.CompareFold(n.h, o.h),
+ stringsx.CompareFold(n.n, o.n),
+ stringsx.CompareFold(n.m, o.m),
+ stringsx.CompareFold(n.t, o.t),
+ )
+}
+
+// String returns the name in the format
+// <host>/<namespace>/<model>:<tag>, omitting any parts that are empty.
+func (n Name) String() string {
+ var b strings.Builder
+ if n.h != "" {
+ b.WriteString(n.h)
+ b.WriteByte('/')
+ }
+ if n.n != "" {
+ b.WriteString(n.n)
+ b.WriteByte('/')
+ }
+ b.WriteString(n.m)
+ if n.t != "" {
+ b.WriteByte(':')
+ b.WriteString(n.t)
+ }
+ return b.String()
+}
+
+func (n Name) GoString() string {
+ return fmt.Sprintf("", n.h, n.n, n.m, n.t)
+}
+
+// cutLastAny is like strings.Cut but scans in reverse for the last character
+// in chars. If no character is found, before is the empty string and after is
+// s. The returned sep is the byte value of the character in chars if one was
+// found; otherwise it is 0.
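+//
+// For example, cutLastAny("n/m:t", "/:") returns ("n/m", "t", ':').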
+func cutLastAny(s, chars string) (before, after string, sep byte) {
+ i := strings.LastIndexAny(s, chars)
+ if i >= 0 {
+ return s[:i], s[i+1:], s[i]
+ }
+ return "", s, 0
+}
diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go
new file mode 100644
index 000000000..760fec5fa
--- /dev/null
+++ b/server/internal/internal/names/name_test.go
@@ -0,0 +1,152 @@
+package names
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestParseName(t *testing.T) {
+ cases := []struct {
+ in string
+ want Name
+ }{
+ {"", Name{}},
+ {"m:t", Name{m: "m", t: "t"}},
+ {"m", Name{m: "m"}},
+ {"/m", Name{m: "m"}},
+ {"/n/m:t", Name{n: "n", m: "m", t: "t"}},
+ {"n/m", Name{n: "n", m: "m"}},
+ {"n/m:t", Name{n: "n", m: "m", t: "t"}},
+ {"n/m", Name{n: "n", m: "m"}},
+ {"n/m", Name{n: "n", m: "m"}},
+ {strings.Repeat("m", MaxNameLength+1), Name{}},
+ {"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}},
+ {"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}},
+
+ // Invalids
+ // TODO: {"n:t/m:t", Name{}},
+ // TODO: {"/h/n/m:t", Name{}},
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.in, func(t *testing.T) {
+ got := Parse(tt.in)
+ if got.Compare(tt.want) != 0 {
+ t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestString(t *testing.T) {
+ cases := []string{
+ "",
+ "m:t",
+ "m:t",
+ "m",
+ "n/m",
+ "n/m:t",
+ "n/m",
+ "n/m",
+ "h/n/m:t",
+ "ollama.com/library/_:latest",
+
+ // Special cased to "round trip" without the leading slash.
+ "/m",
+ "/n/m:t",
+ }
+ for _, s := range cases {
+ t.Run(s, func(t *testing.T) {
+ s = strings.TrimPrefix(s, "/")
+ if g := Parse(s).String(); g != s {
+ t.Errorf("parse(%q).String() = %q", s, g)
+ }
+ })
+ }
+}
+
+func TestParseExtended(t *testing.T) {
+ cases := []struct {
+ in string
+
+ wantScheme string
+ wantName Name
+ wantDigest string
+ }{
+ {"", "", Name{}, ""},
+ {"m", "", Name{m: "m"}, ""},
+ {"http://m", "http", Name{m: "m"}, ""},
+ {"http+insecure://m", "http+insecure", Name{m: "m"}, ""},
+ {"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"},
+ }
+ for _, tt := range cases {
+ t.Run(tt.in, func(t *testing.T) {
+ scheme, name, digest := ParseExtended(tt.in)
+ if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
+ t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
+ }
+
+ // Round trip
+ if got := FormatExtended(scheme, name, digest); got != tt.in {
+ t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
+ }
+ })
+ }
+}
+
+func TestMerge(t *testing.T) {
+ cases := []struct {
+ a, b string
+ want string
+ }{
+ {"", "", ""},
+ {"m", "", "m"},
+ {"", "m", ""},
+ {"x", "y", "x"},
+ {"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"},
+ {"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"},
+
+ {"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"},
+ {"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"},
+ }
+ for _, tt := range cases {
+ t.Run("", func(t *testing.T) {
+ a, b := Parse(tt.a), Parse(tt.b)
+ got := Merge(a, b)
+ if got.Compare(Parse(tt.want)) != 0 {
+ t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseStringRoundTrip(t *testing.T) {
+ cases := []string{
+ "",
+ "m",
+ "m:t",
+ "n/m",
+ "n/m:t",
+ "n/m:t",
+ "n/m",
+ "n/m",
+ "h/n/m:t",
+ "ollama.com/library/_:latest",
+ }
+ for _, s := range cases {
+ t.Run(s, func(t *testing.T) {
+ if got := Parse(s).String(); got != s {
+ t.Errorf("parse(%q).String() = %q", s, got)
+ }
+ })
+ }
+}
+
+var junkName Name
+
+func BenchmarkParseName(b *testing.B) {
+ b.ReportAllocs()
+ for range b.N {
+ junkName = Parse("h/n/m:t")
+ }
+}
diff --git a/server/internal/internal/stringsx/stringsx.go b/server/internal/internal/stringsx/stringsx.go
new file mode 100644
index 000000000..6c7a8d20d
--- /dev/null
+++ b/server/internal/internal/stringsx/stringsx.go
@@ -0,0 +1,52 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package stringsx provides additional string manipulation functions
+// that aren't in the standard library's strings package or go4.org/mem.
+package stringsx
+
+import (
+ "unicode"
+ "unicode/utf8"
+)
+
+// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b,
+// like cmp.Compare, but case insensitively.
+func CompareFold(a, b string) int {
+ // Track our position in both strings
+ ia, ib := 0, 0
+ for ia < len(a) && ib < len(b) {
+ ra, wa := nextRuneLower(a[ia:])
+ rb, wb := nextRuneLower(b[ib:])
+ if ra < rb {
+ return -1
+ }
+ if ra > rb {
+ return 1
+ }
+ ia += wa
+ ib += wb
+ if wa == 0 || wb == 0 {
+ break
+ }
+ }
+
+ // If we've reached here, one or both strings are exhausted
+ // The shorter string is "less than" if they match up to this point
+ switch {
+ case ia == len(a) && ib == len(b):
+ return 0
+ case ia == len(a):
+ return -1
+ default:
+ return 1
+ }
+}
+
+// nextRuneLower returns the next rune in the string, lowercased, along with its
+// original (consumed) width in bytes. If the string is empty, it returns
+// (utf8.RuneError, 0)
+func nextRuneLower(s string) (r rune, width int) {
+ r, width = utf8.DecodeRuneInString(s)
+ return unicode.ToLower(r), width
+}
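CompareFold is a ready-made comparison function for case-insensitive ordering,
e.g. with slices.SortFunc. A small illustrative sketch (not part of this
patch):

	package main

	import (
		"fmt"
		"slices"

		"github.com/ollama/ollama/server/internal/internal/stringsx"
	)

	func main() {
		models := []string{"Smol", "gemma", "LLAMA"}
		// Sort case-insensitively; CompareFold avoids allocating lowercased
		// copies of the strings (the test below asserts zero allocations).
		slices.SortFunc(models, stringsx.CompareFold)
		fmt.Println(models) // [gemma LLAMA Smol]
	}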
diff --git a/server/internal/internal/stringsx/stringsx_test.go b/server/internal/internal/stringsx/stringsx_test.go
new file mode 100644
index 000000000..8575c0b27
--- /dev/null
+++ b/server/internal/internal/stringsx/stringsx_test.go
@@ -0,0 +1,78 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package stringsx
+
+import (
+ "cmp"
+ "strings"
+ "testing"
+)
+
+func TestCompareFold(t *testing.T) {
+ tests := []struct {
+ a, b string
+ }{
+ // Basic ASCII cases
+ {"", ""},
+ {"a", "a"},
+ {"a", "A"},
+ {"A", "a"},
+ {"a", "b"},
+ {"b", "a"},
+ {"abc", "ABC"},
+ {"ABC", "abc"},
+ {"abc", "abd"},
+ {"abd", "abc"},
+
+ // Length differences
+ {"abc", "ab"},
+ {"ab", "abc"},
+
+ // Unicode cases
+ {"世界", "世界"},
+ {"Hello世界", "hello世界"},
+ {"世界Hello", "世界hello"},
+ {"世界", "世界x"},
+ {"世界x", "世界"},
+
+ // Special case folding examples
+ {"ß", "ss"}, // German sharp s
+ {"fi", "fi"}, // fi ligature
+ {"Σ", "σ"}, // Greek sigma
+ {"İ", "i\u0307"}, // Turkish dotted I
+
+ // Mixed cases
+ {"HelloWorld", "helloworld"},
+ {"HELLOWORLD", "helloworld"},
+ {"helloworld", "HELLOWORLD"},
+ {"HelloWorld", "helloworld"},
+ {"helloworld", "HelloWorld"},
+
+ // Edge cases
+ {" ", " "},
+ {"1", "1"},
+ {"123", "123"},
+ {"!@#", "!@#"},
+ }
+
+ wants := []int{}
+ for _, tt := range tests {
+ got := CompareFold(tt.a, tt.b)
+ want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b))
+ if got != want {
+ t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want)
+ }
+ wants = append(wants, want)
+ }
+
+ if n := testing.AllocsPerRun(1000, func() {
+ for i, tt := range tests {
+ if CompareFold(tt.a, tt.b) != wants[i] {
+ panic("unexpected")
+ }
+ }
+ }); n > 0 {
+ t.Errorf("allocs = %v; want 0", int(n))
+ }
+}
diff --git a/server/internal/internal/syncs/line.go b/server/internal/internal/syncs/line.go
new file mode 100644
index 000000000..021cd4c09
--- /dev/null
+++ b/server/internal/internal/syncs/line.go
@@ -0,0 +1,201 @@
+// Package syncs provides synchronization primitives.
+package syncs
+
+import (
+ "cmp"
+ "io"
+ "sync"
+)
+
+var closedChan = func() chan struct{} {
+ ch := make(chan struct{})
+ close(ch)
+ return ch
+}()
+
+// Ticket represents a ticket in a sequence of tickets. The zero value is
+// invalid. Use [Line.Take] to get a valid ticket.
+//
+// A Ticket is not safe for concurrent use.
+type Ticket struct {
+ ahead chan struct{} // ticket ahead of this one
+ ch chan struct{}
+}
+
+// Ready returns a channel that is closed once the ticket ahead of this one,
+// if any, is done. If there is no ticket ahead, the channel is already closed.
+//
+// It is incorrect to wait on Ready after the ticket is done.
+func (t *Ticket) Ready() chan struct{} {
+ return cmp.Or(t.ahead, closedChan)
+}
+
+// Done signals that this ticket is done and that the next ticket in line can
+// proceed.
+//
+// The first call to [Done] unblocks the ticket after it, if any. Subsequent
+// calls are no-ops.
+func (t *Ticket) Done() {
+ if t.ch != nil {
+ close(t.ch)
+ }
+ t.ch = nil
+}
+
+// Line is an ordered sequence of tickets waiting for their turn to proceed.
+//
+// To get a ticket use [Line.Take].
+// To signal that a ticket is done use [Ticket.Done].
+// To wait your turn use [Ticket.Ready].
+//
+// A Line is not safe for concurrent use.
+type Line struct {
+ last chan struct{} // last ticket in line
+}
+
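+// Take appends a new ticket to the end of the line and returns it.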
+func (q *Line) Take() *Ticket {
+ t := &Ticket{
+ ahead: q.last,
+ ch: make(chan struct{}),
+ }
+ q.last = t.ch
+ return t
+}
+
+// RelayReader implements an [io.WriterTo] that relays the writer passed to
+// its [WriteTo] method to each [io.WriteCloser] taken from [Take], in the
+// order they are taken. Each [io.WriteCloser] blocks until the previous one
+// is closed, or until a call to [RelayReader.CloseWithError] is made.
+//
+// The zero value is invalid. Use [NewRelayReader] to get a valid RelayReader.
+//
+// It is not safe for concurrent use.
+type RelayReader struct {
+ line Line
+ t *Ticket
+ w io.Writer
+ n int64
+
+ mu sync.Mutex
+ err error // set by CloseWithError
+ closedCh chan struct{} // closed if err is set
+}
+
+var (
+ _ io.Closer = (*RelayReader)(nil)
+ _ io.WriterTo = (*RelayReader)(nil)
+ _ io.Reader = (*RelayReader)(nil)
+)
+
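+// NewRelayReader returns a RelayReader ready for use.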
+func NewRelayReader() *RelayReader {
+ var q RelayReader
+ q.closedCh = make(chan struct{})
+ q.t = q.line.Take()
+ return &q
+}
+
+// CloseWithError terminates the line, unblocking any writer waiting for its
+// turn with the error, or [io.EOF] if err is nil. It is safe to call
+// [CloseWithError] multiple times and across multiple goroutines.
+//
+// If the line is already closed, [CloseWithError] is a no-op.
+//
+// It never returns an error.
+func (q *RelayReader) CloseWithError(err error) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if q.err == nil {
+ q.err = cmp.Or(q.err, err, io.EOF)
+ close(q.closedCh)
+ }
+ return nil
+}
+
+// Close closes the line. Any writer still waiting for its turn is unblocked,
+// and its subsequent writes return [io.EOF].
+//
+// It never returns an error.
+func (q *RelayReader) Close() error {
+ return q.CloseWithError(nil)
+}
+
+func (q *RelayReader) closed() <-chan struct{} {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ return q.closedCh
+}
+
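+// Read exists only so a RelayReader satisfies [io.Reader] and can be passed
+// to [io.Copy], which will use [RelayReader.WriteTo] instead; calling Read
+// directly panics.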
+func (q *RelayReader) Read(p []byte) (int, error) {
+ panic("RelayReader.Read is for show only; use WriteTo")
+}
+
+// WriteTo yields the writer dst to the first writer in line and blocks until
+// the first call to [Close] or [CloseWithError].
+//
+// It is safe to call [Take] concurrently with [WriteTo].
+func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
+ select {
+ case <-q.closed():
+ return 0, io.ErrClosedPipe
+ default:
+ }
+
+ // We have a destination writer; let the relay begin.
+ q.w = dst
+ q.t.Done()
+ <-q.closed()
+ return q.n, nil
+}
+
+// Take returns the next writer in line; its writes are relayed to the writer
+// passed to [WriteTo] once all earlier writers have been closed.
+//
+// It is not safe for use across multiple goroutines.
+func (q *RelayReader) Take() io.WriteCloser {
+ return &relayWriter{q: q, t: q.line.Take()}
+}
+
+type relayWriter struct {
+ q *RelayReader
+ t *Ticket
+ ready bool
+}
+
+var _ io.StringWriter = (*relayWriter)(nil)
+
+// Write writes to the writer passed to [RelayReader.WriteTo] as soon as it is
+// this writer's turn. If the line is closed before then, it returns the error
+// passed to [RelayReader.CloseWithError], or [io.EOF] if the line was closed
+// with [RelayReader.Close].
+func (w *relayWriter) Write(p []byte) (int, error) {
+ if !w.awaitTurn() {
+ return 0, w.q.err
+ }
+ n, err := w.q.w.Write(p)
+ w.q.n += int64(n)
+ return n, err
+}
+
+func (w *relayWriter) WriteString(s string) (int, error) {
+	if !w.awaitTurn() {
+		return 0, w.q.err
+	}
+	// Count bytes written so WriteTo reports the correct total, as Write does.
+	n, err := io.WriteString(w.q.w, s)
+	w.q.n += int64(n)
+	return n, err
+}
+
+// Close signals that the writer is done, unblocking the next writer in line.
+func (w *relayWriter) Close() error {
+ w.t.Done()
+ return nil
+}
+
+func (t *relayWriter) awaitTurn() (ok bool) {
+ if t.ready {
+ return true
+ }
+ select {
+ case <-t.t.Ready():
+ t.ready = true
+ return true
+ case <-t.q.closed():
+ return false
+ }
+}
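Line and Ticket can also be used on their own to impose an order on otherwise
concurrent work. A minimal sketch (not part of this patch), assuming the
package is imported as syncs:

	package main

	import (
		"fmt"
		"sync"

		"github.com/ollama/ollama/server/internal/internal/syncs"
	)

	func main() {
		var line syncs.Line
		var wg sync.WaitGroup
		for i := range 3 {
			t := line.Take() // take tickets in the desired output order
			wg.Add(1)
			go func() {
				defer wg.Done()
				<-t.Ready() // wait until the ticket ahead, if any, is done
				fmt.Println("step", i)
				t.Done() // unblock the next ticket in line
			}()
		}
		wg.Wait() // prints step 0, step 1, step 2, in that order
	}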
diff --git a/server/internal/internal/syncs/line_test.go b/server/internal/internal/syncs/line_test.go
new file mode 100644
index 000000000..d52160260
--- /dev/null
+++ b/server/internal/internal/syncs/line_test.go
@@ -0,0 +1,65 @@
+package syncs
+
+import (
+ "bytes"
+ "io"
+ "math/rand/v2"
+ "testing"
+ "testing/synctest"
+)
+
+func TestPipelineReadWriterTo(t *testing.T) {
+ for range 10 {
+ synctest.Run(func() {
+ q := NewRelayReader()
+
+ tickets := []struct {
+ io.WriteCloser
+ s string
+ }{
+ {q.Take(), "you"},
+ {q.Take(), " say hi,"},
+ {q.Take(), " and "},
+ {q.Take(), "I say "},
+ {q.Take(), "hello"},
+ }
+
+ rand.Shuffle(len(tickets), func(i, j int) {
+ tickets[i], tickets[j] = tickets[j], tickets[i]
+ })
+
+ var g Group
+ for i, t := range tickets {
+ g.Go(func() {
+ defer t.Close()
+ if i%2 == 0 {
+ // Use [relayWriter.WriteString]
+ io.WriteString(t.WriteCloser, t.s)
+ } else {
+ t.Write([]byte(t.s))
+ }
+ })
+ }
+
+ var got bytes.Buffer
+ var copyErr error // checked at end
+ g.Go(func() {
+ _, copyErr = io.Copy(&got, q)
+ })
+
+ synctest.Wait()
+
+ q.Close()
+ g.Wait()
+
+ if copyErr != nil {
+ t.Fatal(copyErr)
+ }
+
+ want := "you say hi, and I say hello"
+ if got.String() != want {
+ t.Fatalf("got %q, want %q", got.String(), want)
+ }
+ })
+ }
+}
diff --git a/server/internal/internal/syncs/syncs.go b/server/internal/internal/syncs/syncs.go
new file mode 100644
index 000000000..8f1b1e078
--- /dev/null
+++ b/server/internal/internal/syncs/syncs.go
@@ -0,0 +1,41 @@
+package syncs
+
+import (
+ "sync"
+ "sync/atomic"
+)
+
+// Group is a [sync.WaitGroup] with a Go method.
+type Group struct {
+ wg sync.WaitGroup
+ n atomic.Int64
+}
+
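+// Go runs f in a new goroutine tracked by the group.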
+func (g *Group) Go(f func()) {
+ g.wg.Add(1)
+ go func() {
+ g.n.Add(1) // Now we are running
+ defer func() {
+ g.wg.Done()
+ g.n.Add(-1) // Now we are done
+ }()
+ f()
+ }()
+}
+
+// Running returns the number of goroutines that are currently running.
+//
+// If a call to [Running] returns zero, and a call to [Wait] is made without
+// any calls to [Go], then [Wait] will return immediately. This is true even if
+// a goroutine is started and finishes between the two calls.
+//
+// It is possible for [Running] to return non-zero and for [Wait] to return
+// immediately. This can happen if all the running goroutines finish between
+// the two calls.
+func (g *Group) Running() int64 {
+ return g.n.Load()
+}
+
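+// Wait blocks until every goroutine started with [Go] has returned.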
+func (g *Group) Wait() {
+ g.wg.Wait()
+}
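A short sketch of Group in use (not part of this patch): Go starts tracked
goroutines and Wait blocks until they have all returned.

	package main

	import (
		"fmt"

		"github.com/ollama/ollama/server/internal/internal/syncs"
	)

	func main() {
		var g syncs.Group // the zero value is ready to use
		results := make([]int, 4)
		for i := range results {
			g.Go(func() { results[i] = i * i }) // i is per-iteration (Go 1.22+)
		}
		g.Wait() // blocks until all four goroutines return
		fmt.Println(results) // [0 1 4 9]
	}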
diff --git a/server/internal/internal/testutil/testutil.go b/server/internal/internal/testutil/testutil.go
new file mode 100644
index 000000000..354c2608c
--- /dev/null
+++ b/server/internal/internal/testutil/testutil.go
@@ -0,0 +1,74 @@
+package testutil
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+// Check calls t.Fatal(err) if err is not nil.
+func Check(t *testing.T, err error) {
+ if err != nil {
+ t.Helper()
+ t.Fatal(err)
+ }
+}
+
+// CheckFunc exists so other packages do not need to invent their own type for
+// taking a Check function.
+type CheckFunc func(err error)
+
+// Checker returns a check function that
+// calls t.Fatal if err is not nil.
+func Checker(t *testing.T) (check func(err error)) {
+ return func(err error) {
+ if err != nil {
+ t.Helper()
+ t.Fatal(err)
+ }
+ }
+}
+
+// StopPanic runs f but silently recovers from any panic f causes.
+// The normal usage is:
+//
+// testutil.StopPanic(func() {
+// callThatShouldPanic()
+// t.Errorf("callThatShouldPanic did not panic")
+// })
+func StopPanic(f func()) {
+ defer func() { recover() }()
+ f()
+}
+
+// CheckTime calls t.Fatalf if got != want. Included in the error message is
+// want.Sub(got) to help diagnose the difference, along with their values in
+// UTC.
+func CheckTime(t *testing.T, got, want time.Time) {
+ t.Helper()
+ if !got.Equal(want) {
+ t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
+ }
+}
+
+// WriteFile writes data to a file named name. It creates the parent directory
+// if it doesn't exist and writes the file with mode 0o644.
+//
+// The name must be a relative path; otherwise WriteFile fails the test.
+func WriteFile[S []byte | string](t testing.TB, name string, data S) {
+ t.Helper()
+
+ if filepath.IsAbs(name) {
+ t.Fatalf("WriteFile: name must be a relative path, got %q", name)
+ }
+ name = filepath.Clean(name)
+ dir := filepath.Dir(name)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
+ t.Fatal(err)
+ }
+}
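A sketch of how these helpers might compose in a test (not part of this patch;
the test and package names are illustrative). t.Chdir (Go 1.24) keeps
WriteFile's relative path inside a temporary directory:

	package cache_test

	import (
		"os"
		"testing"

		"github.com/ollama/ollama/server/internal/internal/testutil"
	)

	func TestExample(t *testing.T) {
		check := testutil.Checker(t)
		t.Chdir(t.TempDir())

		testutil.WriteFile(t, "blobs/sha256-abc", "hello")

		data, err := os.ReadFile("blobs/sha256-abc")
		check(err)
		if string(data) != "hello" {
			t.Fatalf("got %q, want %q", data, "hello")
		}
	}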
diff --git a/server/internal/manifest/manifest.go b/server/internal/manifest/manifest.go
new file mode 100644
index 000000000..e020d2c02
--- /dev/null
+++ b/server/internal/manifest/manifest.go
@@ -0,0 +1,118 @@
+// Package manifest provides documentation for the Ollama manifest format.
+// This package contains no code.
+//
+// # Manifests
+//
+// A manifest is a JSON object that describes a model. It has a single field,
+// "layers", which lists the layers that make up the model.
+//
+// A layer is a single, logical unit of a model. Layers are stored in the cache
+// as files with the name of the digest of the layer. Layers are pushed and
+// pulled from the registry as blobs.
+//
+// A layer is represented as a JSON object with the following fields:
+//
+// - "digest": The digest of the layer.
+// - "mediaType": The media type of the layer.
+// - "size": The size of the layer in bytes.
+//
+// Layers are typically stored in a blob store, such as a registry, and are
+// referenced by their digest. This package does not define how layers are
+// stored or retrieved.
+//
+// # Configuration Layer
+//
+// The configuration of a model is represented as a layer with the media type:
+//
+//	application/vnd.ollama.image.config; type=<type>
+//
+// The "type" parameter in the media type specifies the format of the
+// configuration (e.g., "safetensor" or "gguf").
+//
+// There may be only one configuration layer in a model.
+//
+// # Template Layer
+//
+// The model template is a layer with the media type:
+//
+//	application/vnd.ollama.image.template; [name=<name>]
+//
+// The "name" parameter in the media type specifies the name of the template as
+// for lookup at runtime. The name is optional and may be omitted. If omitted,
+// the template is the default template for the model.
+//
+// # Tensor Layers
+//
+// The tensors of a model are represented as layers with the media type:
+//
+//	application/vnd.ollama.image.tensor; name=<name>; dtype=<dtype>; shape=<shape>
+//
+// The "name" parameter in the media type specifies the name of the tensor as
+// defined in the model's configuration and are bound only by the rules for
+// names as defined in the configuration format, as represented by the
+// configuration's "type".
+//
+// The "dtype" parameter in the media type specifies the data type of the tensor
+// as a string.
+//
+// TODO: Define more specifically how to represent data types as strings.
+//
+// The "shape" parameter in the media type specifies the shape of the tensor as
+// a comma-separated list of integers; one per dimension.
+//
+// # Tokenization Layers
+//
+// The tokenization of a model is represented as a layer with the media type:
+//
+// application/vnd.ollama.image.tokenizer
+//
+// The configuration of the tokenizer is represented as a layer with the media type:
+//
+// application/vnd.ollama.image.tokenizer.config
+//
+// # Miscellaneous Layers
+//
+// These additional layer media types are reserved:
+//
+// application/vnd.ollama.image.license
+//
+// This layer contains one of the model's licenses, in plain text.
+//
+// # Example Manifest
+//
+// The following is an example manifest containing a configuration, a model
+// template, and two tensors (digests shortened for brevity):
+//
+// {
+// "layers": [{
+// "digest": "sha256:a...",
+// "mediaType": "application/vnd.ollama.image.config; type=safetensors",
+// "size": 1234
+// },{
+// "digest": "sha256:b...",
+// "mediaType": "application/vnd.ollama.image.template",
+// "size": 5678
+// },{
+// "digest": "sha256:c...",
+// "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3",
+// "size": 9012
+// },{
+// "digest": "sha256:d...",
+// "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6",
+// "size": 3456
+// }]
+// }
+//
+// # Legacy Media Types
+//
+// The application/vnd.ollama.image.model media type is deprecated, but will
+// remain supported for backwards compatibility for some undefined amount of
+// time. New models should use the media types defined above.
+//
+// # Reserved Media Types
+//
+// The media type prefix "application/vnd.ollama.image." is reserved for
+// defining new media types for layers known to Ollama. Currently, all other
+// prefixes are ignored by official Ollama registry clients.
+package manifest
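Because the package intentionally contains no code, a consumer of the format
defines its own types. A minimal decoding sketch (the struct and field names
here are illustrative, not an Ollama API):

	package main

	import (
		"encoding/json"
		"fmt"
		"mime"
	)

	type layer struct {
		Digest    string `json:"digest"`
		MediaType string `json:"mediaType"`
		Size      int64  `json:"size"`
	}

	type manifest struct {
		Layers []layer `json:"layers"`
	}

	func main() {
		data := []byte(`{"layers":[{"digest":"sha256:a...","mediaType":"application/vnd.ollama.image.config; type=safetensors","size":1234}]}`)

		var m manifest
		if err := json.Unmarshal(data, &m); err != nil {
			panic(err)
		}
		// The media type carries its parameters (type, name, dtype, shape);
		// mime.ParseMediaType splits them from the base type.
		mt, params, err := mime.ParseMediaType(m.Layers[0].MediaType)
		if err != nil {
			panic(err)
		}
		fmt.Println(mt, params["type"], m.Layers[0].Size)
		// application/vnd.ollama.image.config safetensors 1234
	}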