server/internal: copy bmizerany/ollama-go to internal package (#9294)

This commit copies (without history) the bmizerany/ollama-go repository
with the intention of integrating it into the ollama as a replacement
for the pushing, and pulling of models, and management of the cache they
are pushed and pulled from.

New homes for these packages will be determined as they are integrated
and we have a better understanding of proper package boundaries.
This commit is contained in:
Blake Mizerany 2025-02-24 22:39:44 -08:00 committed by GitHub
parent 0b7e1676eb
commit 348b3e0983
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 4974 additions and 6 deletions

View File

@ -147,6 +147,7 @@ jobs:
runs-on: ${{ matrix.os }}
env:
CGO_ENABLED: '1'
GOEXPERIMENT: 'synctest'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5

2
.gitignore vendored
View File

@ -5,7 +5,6 @@
.swp
dist
build
ollama
.cache
*.exe
.idea
@ -14,3 +13,4 @@ test_data
__debug_bin*
llama/build
llama/vendor
/ollama

View File

@ -6,8 +6,6 @@ linters:
- bidichk
- bodyclose
- containedctx
- contextcheck
- errcheck
- gocheckcompilerdirectives
- gofmt
- gofumpt
@ -23,10 +21,11 @@ linters:
- staticcheck
- tenv
- unconvert
- unused
- usestdlibvars
- wastedassign
- whitespace
disable:
- usestdlibvars
- errcheck
linters-settings:
staticcheck:
checks:
@ -39,5 +38,4 @@ severity:
- gofmt
- goimports
- intrange
- usestdlibvars
severity: info

544
server/internal/cache/blob/cache.go vendored Normal file
View File

@ -0,0 +1,544 @@
// Package blob implements a content-addressable disk cache for blobs and
// manifests.
package blob
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/server/internal/internal/names"
)
// Entry contains metadata about a blob in the cache.
type Entry struct {
Digest Digest
Size int64
Time time.Time // when added to the cache
}
// DiskCache caches blobs and manifests on disk.
//
// The cache is rooted at a directory, which is created if it does not exist.
//
// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the
// "manifests" subdirectory. A example directory structure might look like:
//
// <dir>/
// blobs/
// sha256-<digest> - <blob data>
// manifests/
// <host>/
// <namespace>/
// <name>/
// <tag> - <manifest data>
//
// The cache is safe for concurrent use.
//
// Name casing is preserved in the cache, but is not significant when resolving
// names. For example, "Foo" and "foo" are considered the same name.
//
// The cache is not safe for concurrent use. It guards concurrent writes, but
// does not prevent duplicated effort. Because blobs are immutable, duplicate
// writes should result in the same file being written to disk.
type DiskCache struct {
// Dir specifies the top-level directory where blobs and manifest
// pointers are stored.
dir string
now func() time.Time
testHookBeforeFinalWrite func(f *os.File)
}
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
}
// Open opens a cache rooted at the given directory. If the directory does not
// exist, it is created. If the directory is not a directory, an error is
// returned.
func Open(dir string) (*DiskCache, error) {
if dir == "" {
return nil, errors.New("blob: empty directory name")
}
info, err := os.Stat(dir)
if err == nil && !info.IsDir() {
return nil, fmt.Errorf("%q is not a directory", dir)
}
if err := os.MkdirAll(dir, 0o777); err != nil {
return nil, err
}
subdirs := []string{"blobs", "manifests"}
for _, subdir := range subdirs {
if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil {
return nil, err
}
}
// TODO(bmizerany): support shards
c := &DiskCache{
dir: dir,
now: time.Now,
}
return c, nil
}
func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) {
f, err := os.Open(filename)
if err != nil {
return nil, Digest{}, err
}
defer f.Close()
h := sha256.New()
r := io.TeeReader(f, h)
data, err = io.ReadAll(io.LimitReader(r, limit))
if err != nil {
return nil, Digest{}, err
}
var d Digest
h.Sum(d.sum[:0])
return data, d, nil
}
//lint:ignore U1000 used for debugging purposes as needed in tests
var debug = false
// debugger returns a function that can be used to add a step to the error message.
// The error message will be a list of steps that were taken before the error occurred.
// The steps are added in the order they are called.
//
// To set the error message, call the returned function with an empty string.
//
//lint:ignore U1000 used for debugging purposes as needed in tests
func debugger(err *error) func(step string) {
if !debug {
return func(string) {}
}
var steps []string
return func(step string) {
if step == "" && *err != nil {
*err = fmt.Errorf("%q: %w", steps, *err)
return
}
steps = append(steps, step)
if len(steps) > 100 {
// shift hints in case of a bug that causes a lot of hints
copy(steps, steps[1:])
steps = steps[:100]
}
}
}
// Resolve resolves a name to a digest. The name is expected to
// be in either of the following forms:
//
// @<digest>
// <name>
// <name>
//
// If a digest is provided, it is returned as is and nothing else happens.
//
// If a name is provided for a manifest that exists in the cache, the digest
// of the manifest is returned. If there is no manifest in the cache, it
// returns [fs.ErrNotExist].
//
// To cover the case where a manifest may change without the cache knowing
// (e.g. it was reformatted or modified by hand), the manifest data read and
// hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in
// these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name)
if digest != "" {
return ParseDigest(digest)
}
// We want to address manifests files by digest using Get. That requires
// them to be blobs. This cannot be directly accomplished by looking in
// the blob store because manifests can change without Ollama knowing
// (e.g. a user modifies a manifests by hand then pushes it to update
// their model). We also need to support the blob caches inherited from
// older versions of Ollama, which do not store manifests in the blob
// store, so for these cases, we need to handle adding the manifests to
// the blob store, just in time.
//
// So now we read the manifests file, hash it, and copy it to the blob
// store if it's not already there.
//
// This should be cheap because manifests are small, and accessed
// infrequently.
file, err := c.manifestPath(name)
if err != nil {
return Digest{}, err
}
data, d, err := readAndSum(file, 1<<20)
if err != nil {
return Digest{}, err
}
// Ideally we'd read the "manifest" file as a manifest to the blob file,
// but we are not changing this yet, so copy the manifest to the blob
// store so it can be addressed by digest subsequent calls to Get.
if err := PutBytes(c, d, data); err != nil {
return Digest{}, err
}
return d, nil
}
// Put writes a new blob to the cache, identified by its digest. The operation
// reads content from r, which must precisely match both the specified size and
// digest.
//
// Concurrent write safety is achieved through file locking. The implementation
// guarantees write integrity by enforcing size limits and content validation
// before allowing the file to reach its final state.
func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error {
return c.copyNamedFile(c.GetFile(d), r, d, size)
}
// Import imports a blob from the provided reader into the cache. It reads the
// entire content of the reader, calculates its digest, and stores it in the
// cache.
//
// Import should be considered unsafe for use with untrusted data, such as data
// read from a network. The caller is responsible for ensuring the integrity of
// the data being imported.
func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) {
// users that want to change the temp dir can set TEMPDIR.
f, err := os.CreateTemp("", "blob-")
if err != nil {
return Digest{}, err
}
defer os.Remove(f.Name())
// Copy the blob to a temporary file.
h := sha256.New()
r = io.TeeReader(r, h)
n, err := io.Copy(f, r)
if err != nil {
return Digest{}, err
}
if n != size {
return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n)
}
// Check the digest.
var d Digest
h.Sum(d.sum[:0])
if err := f.Close(); err != nil {
return Digest{}, err
}
name := c.GetFile(d)
// Rename the temporary file to the final file.
if err := os.Rename(f.Name(), name); err != nil {
return Digest{}, err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return d, nil
}
// Get retrieves a blob from the cache using the provided digest. The operation
// fails if the digest is malformed or if any errors occur during blob
// retrieval.
func (c *DiskCache) Get(d Digest) (Entry, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err != nil {
return Entry{}, err
}
if info.Size() == 0 {
return Entry{}, fs.ErrNotExist
}
return Entry{
Digest: d,
Size: info.Size(),
Time: info.ModTime(),
}, nil
}
// Link creates a symbolic reference in the cache that maps the provided name
// to a blob identified by its digest, making it retrievable by name using
// [Resolve].
//
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
manifest, err := c.manifestPath(name)
if err != nil {
return err
}
f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0)
if err != nil {
return err
}
defer f.Close()
// TODO(bmizerany): test this happens only if the blob was found to
// avoid leaving debris
if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil {
return err
}
info, err := f.Stat()
if err != nil {
return err
}
// Copy manifest to cache directory.
return c.copyNamedFile(manifest, f, d, info.Size())
}
// Unlink removes the any link for name. If the link does not exist, nothing
// happens, and no error is returned.
//
// It returns an error if the name is invalid or if the link removal encounters
// any issues.
func (c *DiskCache) Unlink(name string) error {
manifest, err := c.manifestPath(name)
if err != nil {
return err
}
err = os.Remove(manifest)
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
// GetFile returns the absolute path to the file, in the cache, for the given
// digest. It does not check if the file exists.
//
// The returned path should not be stored, used outside the lifetime of the
// cache, or interpreted in any way.
func (c *DiskCache) GetFile(d Digest) string {
filename := fmt.Sprintf("sha256-%x", d.sum)
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
if err != nil {
yield("", err)
return
}
if !yield(pathToName(path), nil) {
return
}
}
}
}
// pathToName converts a path to a name. It is the inverse of nameToPath. The
// path is assumed to be in filepath.ToSlash format.
func pathToName(s string) string {
s = strings.TrimPrefix(s, "manifests/")
rr := []rune(s)
for i := len(rr) - 1; i > 0; i-- {
if rr[i] == '/' {
rr[i] = ':'
return string(rr)
}
}
return s
}
// manifestPath finds the first manifest file on disk that matches the given
// name using a case-insensitive comparison. If no manifest file is found, it
// returns the path where the manifest file would be if it existed.
//
// If two manifest files exists on disk that match the given name using a
// case-insensitive comparison, the one that sorts first, lexically, is
// returned.
func (c *DiskCache) manifestPath(name string) (string, error) {
np, err := nameToPath(name)
if err != nil {
return "", err
}
maybe := filepath.Join("manifests", np)
for l, err := range c.links() {
if err != nil {
return "", err
}
if strings.EqualFold(maybe, l) {
return filepath.Join(c.dir, l), nil
}
}
return filepath.Join(c.dir, maybe), nil
}
// links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) links() iter.Seq2[string, error] {
// TODO(bmizerany): reuse empty dirnames if exist
return func(yield func(string, error) bool) {
fsys := os.DirFS(c.dir)
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
if err != nil {
yield("", err)
return
}
for _, manifest := range manifests {
if !yield(manifest, nil) {
return
}
}
}
}
type checkWriter struct {
d Digest
size int64
n int64
h hash.Hash
f *os.File
err error
testHookBeforeFinalWrite func(*os.File)
}
func (w *checkWriter) seterr(err error) error {
if w.err == nil {
w.err = err
}
return err
}
// Write writes p to the underlying hash and writer. The last write to the
// underlying writer is guaranteed to be the last byte of p as verified by the
// hash.
func (w *checkWriter) Write(p []byte) (int, error) {
_, err := w.h.Write(p)
if err != nil {
return 0, w.seterr(err)
}
nextSize := w.n + int64(len(p))
if nextSize == w.size {
// last write. check hash.
sum := w.h.Sum(nil)
if !bytes.Equal(sum, w.d.sum[:]) {
return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
}
if w.testHookBeforeFinalWrite != nil {
w.testHookBeforeFinalWrite(w.f)
}
}
if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
}
n, err := w.f.Write(p)
w.n += int64(n)
return n, w.seterr(err)
}
// copyNamedFile copies file into name, expecting it to have the given Digest
// and size, if that file is not present already.
func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error {
info, err := os.Stat(name)
if err == nil && info.Size() == size {
// File already exists with correct size. This is good enough.
// We can skip expensive hash checks.
//
// TODO: Do the hash check, but give caller a way to skip it.
return nil
}
// Copy file to cache directory.
mode := os.O_RDWR | os.O_CREATE
if err == nil && info.Size() > size { // shouldn't happen but fix in case
mode |= os.O_TRUNC
}
f, err := os.OpenFile(name, mode, 0o666)
if err != nil {
return err
}
defer f.Close()
if size == 0 {
// File now exists with correct size.
// Only one possible zero-length file, so contents are OK too.
// Early return here makes sure there's a "last byte" for code below.
return nil
}
// From here on, if any of the I/O writing the file fails,
// we make a best-effort attempt to truncate the file f
// before returning, to avoid leaving bad bytes in the file.
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
if err != nil {
f.Truncate(0)
return err
}
if n < size {
f.Truncate(0)
return io.ErrUnexpectedEOF
}
if err := f.Close(); err != nil {
// Data might not have been written,
// but file may look like it is the right size.
// To be extra careful, remove cached file.
os.Remove(name)
return err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return nil
}
func splitNameDigest(s string) (name, digest string) {
i := strings.LastIndexByte(s, '@')
if i < 0 {
return s, ""
}
return s[:i], s[i+1:]
}
var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name)
if !n.IsFullyQualified() {
return "", errInvalidName
}
return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil
}
func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil {
// Likely a bug bug or a bad OS problem. Just panic.
panic(err)
}
return abs
}

685
server/internal/cache/blob/cache_test.go vendored Normal file
View File

@ -0,0 +1,685 @@
package blob
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/internal/testutil"
)
func init() {
debug = true
}
var epoch = func() time.Time {
d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
if d.IsZero() {
panic("time zero")
}
return d
}()
func TestOpenErrors(t *testing.T) {
exe, err := os.Executable()
if err != nil {
panic(err)
}
cases := []struct {
dir string
err string
}{
{t.TempDir(), ""},
{"", "empty directory name"},
{exe, "not a directory"},
}
for _, tt := range cases {
t.Run(tt.dir, func(t *testing.T) {
_, err := Open(tt.dir)
if tt.err == "" {
if err != nil {
t.Fatal(err)
}
return
}
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), tt.err) {
t.Fatalf("err = %v, want %q", err, tt.err)
}
})
}
}
func TestGetFile(t *testing.T) {
t.Chdir(t.TempDir())
c, err := Open(".")
if err != nil {
t.Fatal(err)
}
d := mkdigest("1")
got := c.GetFile(d)
cleaned := filepath.Clean(got)
if cleaned != got {
t.Fatalf("got is unclean: %q", got)
}
if !filepath.IsAbs(got) {
t.Fatal("got is not absolute")
}
abs, _ := filepath.Abs(c.dir)
if !strings.HasPrefix(got, abs) {
t.Fatalf("got is not local to %q", c.dir)
}
}
func TestBasic(t *testing.T) {
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
now := epoch
c.now = func() time.Time { return now }
checkEntry := entryChecker(t, c)
checkFailed := func(err error) {
if err == nil {
t.Helper()
t.Fatal("expected error")
}
}
_, err = c.Resolve("invalid")
checkFailed(err)
_, err = c.Resolve("h/n/m:t")
checkFailed(err)
dx := mkdigest("x")
d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx))
if err != nil {
t.Fatal(err)
}
if d != dx {
t.Fatalf("d = %v, want %v", d, dx)
}
_, err = c.Get(Digest{})
checkFailed(err)
// not committed yet
_, err = c.Get(dx)
checkFailed(err)
err = PutBytes(c, dx, "!")
checkFailed(err)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
checkEntry(dx, 1, now)
t0 := now
now = now.Add(1*time.Hour + 1*time.Minute)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
// check not updated
checkEntry(dx, 1, t0)
}
type sleepFunc func(d time.Duration) time.Time
func openTester(t *testing.T) (*DiskCache, sleepFunc) {
t.Helper()
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
now := epoch
c.now = func() time.Time { return now }
return c, func(d time.Duration) time.Time {
now = now.Add(d)
return now
}
}
func TestManifestPath(t *testing.T) {
check := testutil.Checker(t)
c, sleep := openTester(t)
d1 := mkdigest("1")
err := PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
t0 := sleep(0)
sleep(1 * time.Hour)
err = c.Link("h/n/m:t", d1) // nop expected
check(err)
file := must(c.manifestPath("h/n/m:t"))
info, err := os.Stat(file)
check(err)
testutil.CheckTime(t, info.ModTime(), t0)
}
func TestManifestExistsWithoutBlob(t *testing.T) {
t.Chdir(t.TempDir())
check := testutil.Checker(t)
c, err := Open(".")
check(err)
checkEntry := entryChecker(t, c)
man := must(c.manifestPath("h/n/m:t"))
os.MkdirAll(filepath.Dir(man), 0o777)
testutil.WriteFile(t, man, "1")
got, err := c.Resolve("h/n/m:t")
check(err)
want := mkdigest("1")
if got != want {
t.Fatalf("got = %v, want %v", got, want)
}
e, err := c.Get(got)
check(err)
checkEntry(got, 1, e.Time)
}
func TestPut(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("hello, world")
err := PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
got, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("expected error, got %v", got)
}
// Put a valid blob
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with content that does not hash to the digest
err = PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put the valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob that errors during Read
err = c.Put(d, &errOnBangReader{s: "!"}, 1)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with mismatched size
err = c.Put(d, strings.NewReader("hello, world"), 11)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Final byte does not match the digest (testing commit phase)
err = PutBytes(c, d, "hello, world$")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
reset := c.setTestHookBeforeFinalWrite(func(f *os.File) {
// change mode to read-only
f.Truncate(0)
f.Chmod(0o400)
f.Close()
f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { f1.Close() })
*f = *f1
})
defer reset()
err = PutBytes(c, d, "hello, world")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
reset()
}
func TestImport(t *testing.T) {
c, _ := openTester(t)
checkEntry := entryChecker(t, c)
want := mkdigest("x")
got, err := c.Import(strings.NewReader("x"), 1)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Fatalf("digest = %v, want %v", got, want)
}
checkEntry(want, 1, epoch)
got, err = c.Import(strings.NewReader("x"), 1)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Fatalf("digest = %v, want %v", got, want)
}
checkEntry(want, 1, epoch)
}
func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) {
old := c.testHookBeforeFinalWrite
c.testHookBeforeFinalWrite = h
return func() { c.testHookBeforeFinalWrite = old }
}
func TestPutGetZero(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("x")
err := PutBytes(c, d, "x")
check(err)
checkEntry(d, 1, sleep(0))
err = os.Truncate(c.GetFile(d), 0)
check(err)
_, err = c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
}
func TestPutZero(t *testing.T) {
c, _ := openTester(t)
d := mkdigest("x")
err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content)
testutil.Check(t, err)
checkNotExists(t, c, d)
}
func TestCommit(t *testing.T) {
check := testutil.Checker(t)
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
checkEntry := entryChecker(t, c)
now := epoch
c.now = func() time.Time { return now }
d1 := mkdigest("1")
err = c.Link("h/n/m:t", d1)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
err = PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
got, err := c.Resolve("h/n/m:t")
check(err)
if got != d1 {
t.Fatalf("d = %v, want %v", got, d1)
}
// commit again, more than 1 byte
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("h/n/m:t", d2)
check(err)
checkEntry(d2, 2, now)
filename := must(c.manifestPath("h/n/m:t"))
data, err := os.ReadFile(filename)
check(err)
if string(data) != "22" {
t.Fatalf("data = %q, want %q", data, "22")
}
t0 := now
now = now.Add(1 * time.Hour)
err = c.Link("h/n/m:t", d2) // same contents; nop
check(err)
info, err := os.Stat(filename)
check(err)
testutil.CheckTime(t, info.ModTime(), t0)
}
func TestManifestInvalidBlob(t *testing.T) {
c, _ := openTester(t)
d := mkdigest("1")
err := c.Link("h/n/m:t", d)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
err = PutBytes(c, d, "1")
testutil.Check(t, err)
err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666)
if err != nil {
t.Fatal(err)
}
err = c.Link("h/n/m:t", d)
if !strings.Contains(err.Error(), "underfoot") {
t.Fatalf("err = %v, want error to contain %q", err, "underfoot")
}
}
func TestManifestNameReuse(t *testing.T) {
t.Run("case-insensitive", func(t *testing.T) {
// This should run on all file system types.
testManifestNameReuse(t)
})
t.Run("case-sensitive", func(t *testing.T) {
useCaseInsensitiveTempDir(t)
testManifestNameReuse(t)
})
}
func testManifestNameReuse(t *testing.T) {
check := testutil.Checker(t)
c, _ := openTester(t)
d1 := mkdigest("1")
err := PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("H/N/M:T", d2)
check(err)
var g [2]Digest
g[0], err = c.Resolve("h/n/m:t")
check(err)
g[1], err = c.Resolve("H/N/M:T")
check(err)
w := [2]Digest{d2, d2}
if g != w {
t.Fatalf("g = %v, want %v", g, w)
}
var got []string
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
want := []string{"manifests/h/n/m/t"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
// relink with different case
err = c.Unlink("h/n/m:t")
check(err)
err = c.Link("h/n/m:T", d1)
check(err)
got = got[:0]
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
// we should have only one link that is same case as the last link
want = []string{"manifests/h/n/m/T"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
}
func TestManifestFile(t *testing.T) {
cases := []struct {
in string
want string
}{
{"", ""},
// valid names
{"h/n/m:t", "/manifests/h/n/m/t"},
{"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"},
{"%/%/%/%", ""},
// already a path
{"h/n/m/t", ""},
// refs are not names
{"h/n/m:t@sha256-1", ""},
{"m@sha256-1", ""},
{"n/m:t@sha256-1", ""},
}
c, _ := openTester(t)
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got, err := c.manifestPath(tt.in)
if err != nil && tt.want != "" {
t.Fatalf("unexpected error: %v", err)
}
if err == nil && tt.want == "" {
t.Fatalf("expected error")
}
dir := filepath.ToSlash(c.dir)
got = filepath.ToSlash(got)
got = strings.TrimPrefix(got, dir)
if got != tt.want {
t.Fatalf("got = %q, want %q", got, tt.want)
}
})
}
}
func TestNames(t *testing.T) {
c, _ := openTester(t)
check := testutil.Checker(t)
check(PutBytes(c, mkdigest("1"), "1"))
check(PutBytes(c, mkdigest("2"), "2"))
check(c.Link("h/n/m:t", mkdigest("1")))
check(c.Link("h/n/m:u", mkdigest("2")))
var got []string
for l, err := range c.Links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
want := []string{"h/n/m:t", "h/n/m:u"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
}
func mkdigest(s string) Digest {
return Digest{sha256.Sum256([]byte(s))}
}
func checkNotExists(t *testing.T, c *DiskCache, d Digest) {
t.Helper()
_, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
}
func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) {
t.Helper()
return func(d Digest, size int64, mod time.Time) {
t.Helper()
t.Run("checkEntry:"+d.String(), func(t *testing.T) {
t.Helper()
defer func() {
if t.Failed() {
dumpCacheContents(t, c)
}
}()
e, err := c.Get(d)
if size == 0 && errors.Is(err, fs.ErrNotExist) {
err = nil
}
if err != nil {
t.Fatal(err)
}
if e.Digest != d {
t.Errorf("e.Digest = %v, want %v", e.Digest, d)
}
if e.Size != size {
t.Fatalf("e.Size = %v, want %v", e.Size, size)
}
testutil.CheckTime(t, e.Time, mod)
info, err := os.Stat(c.GetFile(d))
if err != nil {
t.Fatal(err)
}
if info.Size() != size {
t.Fatalf("info.Size = %v, want %v", info.Size(), size)
}
testutil.CheckTime(t, info.ModTime(), mod)
})
}
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}
func TestNameToPath(t *testing.T) {
_, err := nameToPath("h/n/m:t")
if err != nil {
t.Fatal(err)
}
}
type errOnBangReader struct {
s string
n int
}
func (e *errOnBangReader) Read(p []byte) (int, error) {
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
if e.n >= len(p) {
return 0, io.EOF
}
if e.s[e.n] == '!' {
return 0, errors.New("bang")
}
p[0] = e.s[e.n]
e.n++
return 1, nil
}
func dumpCacheContents(t *testing.T, c *DiskCache) {
t.Helper()
var b strings.Builder
fsys := os.DirFS(c.dir)
fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
t.Helper()
if err != nil {
return err
}
info, err := d.Info()
if err != nil {
return err
}
// Format like ls:
//
// ; ls -la
// drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123
// drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m
fmt.Fprintf(&b, " %s % 4d %s %s\n",
info.Mode(),
info.Size(),
info.ModTime().Format("Jan 2 15:04"),
path,
)
return nil
})
t.Log()
t.Logf("cache contents:\n%s", b.String())
}

View File

@ -0,0 +1,93 @@
package blob
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func isCaseSensitive(dir string) bool {
defer func() {
os.Remove(filepath.Join(dir, "_casecheck"))
}()
exists := func(file string) bool {
_, err := os.Stat(file)
return err == nil
}
file := filepath.Join(dir, "_casecheck")
FILE := filepath.Join(dir, "_CASECHECK")
if exists(file) || exists(FILE) {
panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir))
}
err := os.WriteFile(file, nil, 0o666)
if err != nil {
panic(err)
}
return !exists(FILE)
}
func isCI() bool {
return os.Getenv("CI") != ""
}
const volumeHint = `
Unable to locate case-insensitive TMPDIR on darwin.
To run tests, create the case-insensitive volume /Volumes/data:
$ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data
or run with:
CI=1 go test ./...
`
// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory
// can find one, otherwise it skips the test if the CI environment variable is
// set, or GOOS is not darwin.
func useCaseInsensitiveTempDir(t *testing.T) bool {
if isCaseSensitive(os.TempDir()) {
// Use the default temp dir if it is already case-sensitive.
return true
}
if runtime.GOOS == "darwin" {
// If darwin, check for the special case-sensitive volume and
// use it if available.
const volume = "/Volumes/data"
_, err := os.Stat(volume)
if err == nil {
tmpdir := filepath.Join(volume, "tmp")
os.MkdirAll(tmpdir, 0o700)
t.Setenv("TMPDIR", tmpdir)
return true
}
if isCI() {
// Special case darwin in CI; it is not case-sensitive
// by default, and we will be testing other platforms
// that are case-sensitive, so we'll have the test
// being skipped covered there.
t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.")
}
}
if !isCI() {
// Require devs to always tests with a case-insensitive TMPDIR.
// TODO(bmizerany): Print platform-specific instructions or
// link to docs on that topic.
lines := strings.Split(volumeHint, "\n")
for _, line := range lines {
t.Log(line)
}
}
return false
}

95
server/internal/cache/blob/digest.go vendored Normal file
View File

@ -0,0 +1,95 @@
package blob
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"slices"
"strings"
)
var ErrInvalidDigest = errors.New("invalid digest")
// Digest is a blob identifier that is the SHA-256 hash of a blob's content.
//
// It is comparable and can be used as a map key.
type Digest struct {
sum [32]byte
}
// ParseDigest parses a digest from a string. If the string is not a valid
// digest, a call to the returned digest's IsValid method will return false.
//
// The input string may be in one of two forms:
//
// - ("sha256-<hex>"), where <hex> is a 64-character hexadecimal string.
// - ("sha256:<hex>"), where <hex> is a 64-character hexadecimal string.
//
// The [Digest.String] method will return the canonical form of the
// digest, "sha256:<hex>".
func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) {
s := string(v)
i := strings.IndexAny(s, ":-")
var zero Digest
if i < 0 {
return zero, ErrInvalidDigest
}
prefix, sum := s[:i], s[i+1:]
if prefix != "sha256" || len(sum) != 64 {
return zero, ErrInvalidDigest
}
var d Digest
_, err := hex.Decode(d.sum[:], []byte(sum))
if err != nil {
return zero, ErrInvalidDigest
}
return d, nil
}
func DigestFromBytes[S ~[]byte | ~string](v S) Digest {
return Digest{sha256.Sum256([]byte(v))}
}
// String returns the string representation of the digest in the conventional
// form "sha256:<hex>".
func (d Digest) String() string {
return fmt.Sprintf("sha256:%x", d.sum[:])
}
func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4])
}
func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:])
}
// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash
// of some content.
func (d Digest) IsValid() bool {
return d != (Digest{})
}
// MarshalText implements the encoding.TextMarshaler interface. It returns an
// error if [Digest.IsValid] returns false.
func (d Digest) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface, and only
// works for a zero digest. If [Digest.IsValid] returns true, it returns an
// error.
func (d *Digest) UnmarshalText(text []byte) error {
if *d != (Digest{}) {
return errors.New("digest: illegal UnmarshalText on valid digest")
}
v, err := ParseDigest(string(text))
if err != nil {
return err
}
*d = v
return nil
}

View File

@ -0,0 +1,63 @@
package blob
import (
"encoding/json"
"testing"
)
func TestParseDigest(t *testing.T) {
cases := []struct {
in string
valid bool
}{
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
// too short
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
// too long
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
// invalid prefix
{"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
{"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
{"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
// invalid hex
{"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
{"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
}
for _, tt := range cases {
got, err := ParseDigest(tt.in)
if tt.valid && err != nil {
t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err)
}
want := "sha256:" + tt.in[7:]
if tt.valid && got.String() != want {
t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want)
}
}
}
func TestDigestMarshalText(t *testing.T) {
const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
var d Digest
if err := json.Unmarshal([]byte(s), &d); err != nil {
t.Errorf("json.Unmarshal: %v", err)
}
out, err := json.Marshal(d)
if err != nil {
t.Errorf("json.Marshal: %v", err)
}
want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
if string(out) != want {
t.Errorf("json.Marshal: got %s, want %s", out, want)
}
if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil {
t.Errorf("json.Unmarshal: expected error")
}
}

View File

@ -0,0 +1,78 @@
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}

View File

@ -0,0 +1,65 @@
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}

View File

@ -0,0 +1,802 @@
// Package ollama provides a client for interacting with an Ollama registry
// which pushes and pulls model manifests and layers as defined by the
// [ollama.com/manifest].
package ollama
import (
"bufio"
"bytes"
"cmp"
"context"
"crypto"
"crypto/ed25519"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed"
)
// Errors
var (
// ErrManifestNotFound is returned when a manifest is not found in the
// cache or registry.
ErrManifestNotFound = errors.New("manifest not found")
// ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid.
ErrManifestInvalid = errors.New("invalid manifest")
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
)
// Defaults
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
)
// DefaultCache returns a new disk cache for storing models. If the
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
}
// Error is the standard error returned by Ollama APIs.
type Error struct {
Status int `json:"-"`
Code string `json:"code"`
Message string `json:"message"`
}
func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct{ Errors []E }
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
*e = Error(v.Errors[0]) // our registry only returns one error.
return nil
}
// TODO(bmizerany): make configurable on [Registry]
var defaultName = func() names.Name {
n := names.Parse("ollama.com/library/_:latest")
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
}
return n
}()
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
// Key is the key used to authenticate with the registry.
//
// Currently, only Ed25519 keys are supported.
Key crypto.PrivateKey
// HTTPClient is the HTTP client used to make requests to the registry.
//
// If nil, [http.DefaultClient] is used.
//
// As a quick note: If a Registry function that makes a call to a URL
// with the "https+insecure" scheme, the client will be cloned and the
// transport will be set to skip TLS verification, unless the client's
// Transport done not have a Clone method with the same signature as
// [http.Transport.Clone], which case, the call will fail.
HTTPClient *http.Client
// MaxStreams is the maximum number of concurrent streams to use when
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// Clients that want "unlimited" streams should set this to a large
// number.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
}
// RegistryFromEnv returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
//
// It returns an error if any configuration in the environment is invalid.
func RegistryFromEnv() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil {
return nil, err
}
var rc Registry
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
return nil, err
}
maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
if maxStreams != "" {
var err error
rc.MaxStreams, err = strconv.Atoi(maxStreams)
if err != nil {
return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
}
}
return &rc, nil
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
// parseName parses name using [names.ParseExtended] and then merges the name with the
// default name, and checks that the name is fully qualified. If a digest is
// present, it parse and returns it with the other fields as their zero values.
//
// It returns an error if the name is not fully qualified, or if the digest, if
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
scheme, n, ds := names.ParseExtended(s)
n = names.Merge(n, defaultName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
}
// The name check is deferred until after the digest check because we
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
}
scheme = cmp.Or(scheme, "https")
return scheme, n, d, nil
}
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
// Large downloads require a writter stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
}
func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
}
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
m, err := ResolveLocal(c, cmp.Or(p.From, name))
if err != nil {
return err
}
// Before much else happens, check layers at not null, the blobs exist,
// and the sizes match. This prevents long uploads followed by
// disappointment.
for _, l := range m.Layers {
if l == nil {
return fmt.Errorf("%w: null layer", ErrManifestInvalid)
}
info, err := c.Get(l.Digest)
if err != nil {
return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
}
if info.Size != l.Size {
return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
}
}
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
panic(err)
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range m.Layers {
var progress atomic.Int64
g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }()
t.update(l, 0, nil)
startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
res, err := r.doOK(ctx, "POST", startURL, nil)
if err != nil {
return err
}
res.Body.Close()
f, err := os.Open(c.GetFile(l.Digest))
if err != nil {
return err
}
defer f.Close()
uploadURL := res.Header.Get("Location")
if uploadURL == "" {
t.update(l, l.Size, ErrCached)
return nil
}
req, err := r.newRequest(ctx, "PUT", uploadURL, f)
if err != nil {
return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
}
req.ContentLength = l.Size
res, err = doOK(r.client(), req)
if err == nil {
res.Body.Close()
}
return err
})
}
if err := g.Wait(); err != nil {
return err
}
// Commit
path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
n.Tag(),
)
res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
if err == nil {
res.Body.Close()
}
// TODO(bmizerany): add a "commit" trace event
return err
}
func canRetry(err error) bool {
var re *Error
if !errors.As(err, &re) {
return false
}
return re.Status >= 500
}
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name)
if err != nil {
return err
}
m, err := r.Resolve(ctx, name)
if err != nil {
return err
}
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
t := traceFromContext(ctx)
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range m.Layers {
if exists(l) {
t.update(l, l.Size, ErrCached)
continue
}
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
if err != nil {
t.update(l, 0, err)
continue
}
t.update(l, 0, nil)
if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error {
res, err := doOK(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
})
} else {
q := syncs.NewRelayReader()
g.Go(func() (err error) {
defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
})
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := doOK(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
streamNo := 0
tws := make([]*bufio.Writer, r.maxStreams()-1)
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
ticket := q.Take()
bufIdx := streamNo % len(tws)
streamNo++
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := doOK(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := tws[bufIdx]
if tw == nil {
tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
tws[bufIdx] = tw
}
tw.Reset(ticket)
defer tw.Reset(nil) // release ticket
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
}
}
if err := g.Wait(); err != nil {
return err
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
// commit the manifest with a link
return c.Link(m.Name, md)
}
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest
Layers []*Layer `json:"layers"`
}
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// Layer returns the layer with the given
// digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer {
for _, l := range m.Layers {
if l.Digest == d {
return l
}
}
return nil
}
// MarshalJSON implements json.Marshaler.
//
// NOTE: It adds an empty config object to the manifest, which is required by
// the registry, but not used by the client. In the future, the config object
// will not be required by the registry and this will should be removed.
func (m Manifest) MarshalJSON() ([]byte, error) {
type M Manifest
v := struct {
M
// This is ignored, mostly, by the registry But, if not
// present, it will cause an error to be returned during the
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config *Layer `json:"config"`
}{
M: M(m),
Config: &Layer{Digest: emptyDigest},
}
return json.Marshal(v)
}
// unmarshalManifest unmarshals the data into a manifest, and sets the name
// field to the string representation of the name.
//
// It panics if the name is not fully qualified. Callers should ensure the name
// is fully qualified before calling this function.
func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
if !n.IsFullyQualified() {
panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
}
var m Manifest
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
}
m.Name = n.String()
m.Data = data
return &m, nil
}
// Layer is a layer in a model.
type Layer struct {
Digest blob.Digest `json:"digest"`
MediaType string `json:"mediaType"`
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name)
if err != nil {
return nil, err
}
if !d.IsValid() {
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
}
}
data, err := os.ReadFile(c.GetFile(d))
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
}
return nil, err
}
m, err := unmarshalManifest(n, data)
if err != nil {
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
}
return m, nil
}
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name)
if err != nil {
return nil, err
}
manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
if d.IsValid() {
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
}
res, err := r.doOK(ctx, "GET", manifestURL, nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// TODO(bmizerany): return digest here
m, err := unmarshalManifest(n, data)
if err != nil {
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
}
return m, nil
}
func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
}
return http.DefaultClient
}
// newRequest constructs a new request, ready to use, with the given method,
// url, and body, presigned with client Key and UserAgent.
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
if r.UserAgent != "" {
req.Header.Set("User-Agent", r.UserAgent)
}
if r.Key != nil {
token, err := makeAuthToken(r.Key)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
}
return req, nil
}
// doOK makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc.
type cloner interface {
Clone() *http.Transport
}
// Attempt to configure the transport to skip TLS verification
// if we can clone it, otherwise fall through and let the http
// client complain and the scheme being invalid.
x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
if ok {
tr := x.Clone()
tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
tr.TLSClientConfig.InsecureSkipVerify = true
cc := *c // shallow copy
cc.Transport = tr
c = &cc
r = r.Clone(r.Context())
r.URL.Scheme = "https"
// fall through
}
}
res, err := c.Do(r)
if err != nil {
return nil, err
}
if res.StatusCode/100 != 2 {
out, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var re Error
if err := json.Unmarshal(out, &re); err != nil {
// Use the raw body if we can't parse it as an error object.
re.Message = string(out)
}
re.Status = res.StatusCode
return nil, &re
}
return res, nil
}
// doOK is a convenience method for making a request with newRequest and
// passing it to doOK with r.client().
func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := r.newRequest(ctx, method, path, body)
if err != nil {
return nil, err
}
return doOK(r.client(), req)
}
// makeAuthToken creates an Ollama auth token for the given private key.
//
// NOTE: This format is OLD, overly complex, and should be replaced. We're
// inheriting it from the original Ollama client and ollama.com
// implementations, so we need to support it for now.
func makeAuthToken(key crypto.PrivateKey) (string, error) {
privKey, _ := key.(*ed25519.PrivateKey)
if privKey == nil {
return "", fmt.Errorf("unsupported private key type: %T", key)
}
url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
// Part 1: the checkData (e.g. the URL with a timestamp)
// Part 2: the public key
pubKeyShort, err := func() ([]byte, error) {
sshPubKey, err := ssh.NewPublicKey(privKey.Public())
if err != nil {
return nil, err
}
pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
if len(pubKeyParts) < 2 {
return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
}
pubKeyShort := pubKeyParts[1]
return pubKeyShort, nil
}()
if err != nil {
return "", err
}
// Part 3: the signature
sig := ed25519.Sign(*privKey, []byte(checkData(url)))
// Assemble the token: <checkData>:<pubKey>:<signature>
var b strings.Builder
io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
b.WriteByte(':')
b.Write(pubKeyShort)
b.WriteByte(':')
io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
return b.String(), nil
}
// The original spec for Ollama tokens was to use the SHA256 of the zero
// string as part of the signature. I'm not sure why that was, but we still
// need it to verify the signature.
var zeroSum = func() string {
sha256sum := sha256.Sum256(nil)
x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))
return x
}()
// checkData takes a URL and creates the original string format of the
// data signature that is used by the ollama client to sign requests
func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}

View File

@ -0,0 +1,656 @@
package ollama
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"math/rand/v2"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/testutil"
)
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
data, err := json.Marshal(m)
if err != nil {
t.Fatal(err)
}
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
t.Error("expected manifest to contain empty config")
t.Fatalf("got:\n%s", string(data))
}
}
func link(c *blob.DiskCache, name string, manifest string) {
_, n, _, err := parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
var errRoundTrip = errors.New("forced roundtrip error")
type recordRoundTripper http.HandlerFunc
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
w := httptest.NewRecorder()
rr(w, req)
if w.Code == 499 {
return nil, errRoundTrip
}
resp := w.Result()
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
// set it manually.
resp.Request = req
return w.Result(), nil
}
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
mklayer := func(data string) *Layer {
return &Layer{
Digest: importBytes(t, c, data),
Size: int64(len(data)),
}
}
commit := func(name string, layers ...*Layer) {
t.Helper()
data, err := json.Marshal(&Manifest{Layers: layers})
if err != nil {
t.Fatal(err)
}
link(c, name, string(data))
}
link(c, "empty", "")
commit("zero")
commit("single", mklayer("exists"))
commit("multiple", mklayer("exists"), mklayer("present"))
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
commit("null", nil)
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link(c, "invalid", "!!!!!")
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
return rc, c
}
func okHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper()
var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code)
}
}
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
d, err := c.Import(strings.NewReader(data), int64(len(data)))
if err != nil {
t.Fatal(err)
}
return d
}
func TestRegistryPushInvalidNames(t *testing.T) {
rc, c := newClient(t, nil)
cases := []struct {
name string
err error
}{
{"", ErrNameInvalid},
{"@", ErrNameInvalid},
{"@x", blob.ErrInvalidDigest},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// Create a new registry and push a new image.
err := rc.Push(t.Context(), c, tt.name, nil)
if !errors.Is(err, tt.err) {
t.Errorf("err = %v; want %v", err, tt.err)
}
})
}
}
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
return WithTrace(ctx, t), t
}
func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}
func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil)
testutil.Check(t, err)
}
func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil)
testutil.Check(t, err)
}
func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
err := rc.Push(t.Context(), c, "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}
func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}
func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
pushed = true
w.Header().Set("Location", "http://blob.store/blobs/123")
return
}
w.WriteHeader(http.StatusAccepted)
return
}
io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
})
rc.MaxStreams = 1 // prevent concurrent uploads
var errs []error
ctx := WithTrace(t.Context(), &Trace{
Update: func(_ *Layer, n int64, err error) {
// uploading one at a time so no need to lock
errs = append(errs, err)
},
})
check := testutil.Checker(t)
err := rc.Push(ctx, c, "single", nil)
check(err)
if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}
err = rc.Push(ctx, c, "single", nil)
check(err)
}
func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
got := rc.Push(t.Context(), c, "single", nil)
checkErrCode(t, got, 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
got := rc.Push(t.Context(), c, "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
}
}
func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
got := rc.Push(t.Context(), c, "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
}
func TestPushUploadFileOpenError(t *testing.T) {
rc, c := newClient(t, okHandler)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, _ int64, err error) {
// Remove the file just before it is opened for upload,
// but after the initial Stat that happens before the
// upload starts
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, c, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
err := rc.Push(t.Context(), c, "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
func checkNotExist(t *testing.T, err error) {
t.Helper()
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v; want fs.ErrNotExist", err)
}
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
}
func TestRegistryPullInvalidManifest(t *testing.T) {
cases := []string{
"",
"null",
"!!!",
`{"layers":[]}`,
}
for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), c, "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
}
}
func TestRegistryPullNotCached(t *testing.T) {
check := testutil.Checker(t)
var c *blob.DiskCache
var rc *Registry
d := blob.DigestFromBytes("some data")
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
io.WriteString(w, "some data")
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
})
// Confirm that the layer does not exist locally
_, err := ResolveLocal(c, "model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), c, "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := ResolveLocal(c, "model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
// Confirm successful download
info, err := c.Get(d)
check(err)
if info.Digest != d {
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
}
if info.Size != 9 {
t.Errorf("info.Size = %v; want %v", info.Size, 9)
}
data, err := os.ReadFile(c.GetFile(d))
check(err)
if string(data) != "some data" {
t.Errorf("data = %q; want %q", data, "exists")
}
}
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
}
if strings.Contains(r.URL.Path, "/manifests/") {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
}
})
var errs []error
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
t.Logf("update %v %d %v", d, n, err)
reads = append(reads, n)
errs = append(errs, err)
},
})
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, c, "single")
testutil.Check(t, err)
want := []int64{6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
if !slices.Equal(reads, want) {
t.Errorf("pairs = %v; want %v", reads, want)
}
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), c, "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), c, "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), c, "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
// interfere with pulling layers that are not cached
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
x := blob.DigestFromBytes("xxxxxx")
e := blob.DigestFromBytes("exists")
y := blob.DigestFromBytes("yyyyyy")
for i := range 10 {
t.Logf("iteration %d", i)
digests := []blob.Digest{x, e, y}
rand.Shuffle(len(digests), func(i, j int) {
digests[i], digests[j] = digests[j], digests[i]
})
manifest := fmt.Sprintf(`{
"layers": [
{"digest":"%s","size":6},
{"digest":"%s","size":6},
{"digest":"%s","size":6}
]
}`, digests[0], digests[1], digests[2])
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "latest":
io.WriteString(w, manifest)
case x.String():
io.WriteString(w, "xxxxxx")
case e.String():
io.WriteString(w, "exists")
case y.String():
io.WriteString(w, "yyyyyy")
default:
panic(fmt.Sprintf("unexpected request: %v", r))
}
})
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("update %v %d %v", l, n, err)
},
})
// Check that we pull all layers that we can.
err := rc.Pull(ctx, c, "mixed")
if err != nil {
t.Fatal(err)
}
for _, d := range digests {
info, err := c.Get(d)
if err != nil {
t.Fatalf("Get(%v): %v", d, err)
}
if info.Size != 6 {
t.Errorf("info.Size = %v; want %v", info.Size, 6)
}
}
}
}
func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, c, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
exists := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
w.WriteHeader(499) // should not hit manifest endpoint
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
})
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
check(err)
}
func TestInsecureSkipVerify(t *testing.T) {
exists := blob.DigestFromBytes("exists")
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
}))
defer s.Close()
const name = "ollama.com/library/insecure"
var rc Registry
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
t.Errorf("err = %v; want cert verifiction failure", err)
}
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
_, err = rc.Resolve(t.Context(), url)
testutil.Check(t, err)
}
func TestCanRetry(t *testing.T) {
cases := []struct {
err error
want bool
}{
{nil, false},
{errors.New("x"), false},
{ErrCached, false},
{ErrManifestInvalid, false},
{ErrNameInvalid, false},
{&Error{Status: 100}, false},
{&Error{Status: 500}, true},
}
for _, tt := range cases {
if got := canRetry(tt.err); got != tt.want {
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
}
}
}

View File

@ -0,0 +1,48 @@
package ollama
import (
"context"
)
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
//
// It is called once at the beginning of the download with a zero n and
// then once per read operation with the number of bytes read so far,
// and an error if any.
//
// A function assigned must be safe for concurrent use. The function is
// called synchronously and so should not block or take long to run.
Update func(_ *Layer, n int64, _ error)
}
func (t *Trace) update(l *Layer, n int64, err error) {
if t.Update != nil {
t.Update(l, n, err)
}
}
type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace
// events.
func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t)
}
var emptyTrace = &Trace{}
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
// none is found.
//
// It never returns nil.
func traceFromContext(ctx context.Context) *Trace {
t, _ := ctx.Value(traceKey{}).(*Trace)
if t == nil {
return emptyTrace
}
return t
}

View File

@ -0,0 +1,220 @@
// safetensors provides a reader for the safetensor directories and files.
package safetensors
import (
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"slices"
"strconv"
"strings"
)
// Tensor represents a single tensor in a safetensors file.
//
// It's zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
name string
dataType string
shape []int64
fsys fs.FS
fname string // entry name in fsys
offset int64
size int64
}
type Model struct {
fsys fs.FS
}
func Read(fsys fs.FS) (*Model, error) {
return &Model{fsys: fsys}, nil
}
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
return func(yield func(*Tensor, error) bool) {
entries, err := fs.Glob(m.fsys, "*.safetensors")
if err != nil {
yield(nil, err)
return
}
for _, e := range entries {
tt, err := m.readTensors(e)
if err != nil {
yield(nil, err)
return
}
for _, t := range tt {
if !yield(t, nil) {
return
}
}
}
}
}
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
f, err := m.fsys.Open(fname)
if err != nil {
return nil, err
}
defer f.Close()
finfo, err := f.Stat()
if err != nil {
return nil, err
}
headerSize, err := readInt64(f)
if err != nil {
return nil, err
}
data := make([]byte, headerSize)
_, err = io.ReadFull(f, data)
if err != nil {
return nil, err
}
var raws map[string]json.RawMessage
if err := json.Unmarshal(data, &raws); err != nil {
return nil, err
}
// TODO(bmizerany): do something with metadata? This could be another
// header read if needed. We also need to figure out if the metadata is
// present in only one .safetensors file or if each file may have their
// own and if it needs to follow each tensor. Currently, I (bmizerany)
// am only seeing them show up with one entry for file type which is
// always "pt".
tt := make([]*Tensor, 0, len(raws))
for name, raw := range raws {
if !strings.HasPrefix(name, "model.layer") {
continue
}
var v struct {
DataType string `json:"dtype"`
Shape []int64 `json:"shape"`
Offsets []int64 `json:"data_offsets"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
}
if len(v.Offsets) != 2 {
return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
}
// TODO(bmizerany): after collecting, validate all offests make
// tensors contiguous?
begin, end := v.Offsets[0], v.Offsets[1]
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
return nil, err
}
// TODO(bmizerany): just yield.. don't be silly and make a slice :)
tt = append(tt, &Tensor{
name: name,
dataType: v.DataType,
shape: v.Shape,
fsys: m.fsys,
fname: fname,
offset: begin,
size: end - begin,
})
}
return tt, nil
}
func checkBeginEnd(size, begin, end int64) error {
if begin < 0 {
return fmt.Errorf("begin must not be negative: %d", begin)
}
if end < 0 {
return fmt.Errorf("end must not be negative: %d", end)
}
if end < begin {
return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
}
if end > size {
return fmt.Errorf("end must be <= size: %d > %d", end, size)
}
return nil
}
func readInt64(r io.Reader) (int64, error) {
var v uint64
var buf [8]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return 0, err
}
for i := range buf {
v |= uint64(buf[i]) << (8 * i)
}
return int64(v), nil
}
type Shape []int64
func (s Shape) String() string {
var b strings.Builder
b.WriteByte('[')
for i, v := range s {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(strconv.FormatInt(v, 10))
}
b.WriteByte(']')
return b.String()
}
func (t *Tensor) Name() string { return t.name }
func (t *Tensor) DataType() string { return t.dataType }
func (t *Tensor) Size() int64 { return t.size }
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
func (t *Tensor) Reader() (io.ReadCloser, error) {
f, err := t.fsys.Open(t.fname)
if err != nil {
return nil, err
}
r := newSectionReader(f, t.offset, t.size)
rc := struct {
io.Reader
io.Closer
}{r, f}
return rc, nil
}
// newSectionReader returns a new io.Reader that reads from r starting at
// offset. It is a convenience function for creating a io.SectionReader when r
// may not be an io.ReaderAt.
//
// If r is already a ReaderAt, it is returned directly, otherwise if r is an
// io.Seeker, a new io.ReaderAt is returned that wraps r after seeking to the
// beginning of the file.
//
// If r is an io.Seeker,
// or slow path. The slow path is used when r does not implement io.ReaderAt,
// in which case it must discard the data it reads.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
if r, ok := r.(io.ReaderAt); ok {
return io.NewSectionReader(r, offset, n)
}
if r, ok := r.(io.ReadSeeker); ok {
r.Seek(offset, io.SeekStart)
return io.LimitReader(r, n)
}
// Discard to offset and return a limited reader.
_, err := io.CopyN(io.Discard, r, offset)
if err != nil {
return nil
}
return io.LimitReader(r, n)
}

View File

@ -0,0 +1,366 @@
package main
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
"golang.org/x/sync/errgroup"
)
var stdout io.Writer = os.Stdout
const usage = `Opp is a tool for pushing and pulling Ollama models.
Usage:
opp [flags] <push|pull|import>
Commands:
push Upload a model to the Ollama server.
pull Download a model from the Ollama server.
import Import a model from a local safetensor directory.
Examples:
# Pull a model from the Ollama server.
opp pull library/llama3.2:latest
# Push a model to the Ollama server.
opp push username/my_model:8b
# Import a model from a local safetensor directory.
opp import /path/to/safetensor
Envionment Variables:
OLLAMA_MODELS
The directory where models are pushed and pulled from
(default ~/.ollama/models).
`
func main() {
flag.Usage = func() {
fmt.Fprint(os.Stderr, usage)
}
flag.Parse()
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
rc, err := ollama.RegistryFromEnv()
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
err = func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
return cmdPull(ctx, rc, c)
case "push":
return cmdPush(ctx, rc, c)
case "import":
return cmdImport(ctx, c)
default:
if cmd == "" {
flag.Usage()
} else {
fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
}
os.Exit(2)
return errors.New("unreachable")
}
}()
if err != nil {
fmt.Fprintf(os.Stderr, "opp: %v\n", err)
os.Exit(1)
}
}
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
os.Exit(1)
}
tr := http.DefaultTransport.(*http.Transport).Clone()
// TODO(bmizerany): configure transport?
rc.HTTPClient = &http.Client{Transport: tr}
var mu sync.Mutex
p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
var pb bytes.Buffer
printProgress := func() {
pb.Reset()
mu.Lock()
for d, s := range p {
// Write progress to a buffer first to avoid blocking
// on stdout while holding the lock.
stamp := time.Now().Format("2006/01/02 15:04:05")
fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
if s[0] == s[1] {
delete(p, d)
}
}
mu.Unlock()
io.Copy(stdout, &pb)
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if err != nil && !errors.Is(err, ollama.ErrCached) {
fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
return
}
mu.Lock()
p[l.Digest] = [2]int64{l.Size, n}
mu.Unlock()
},
})
errc := make(chan error)
go func() {
errc <- rc.Pull(ctx, c, model)
}()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-t.C:
printProgress()
case err := <-errc:
printProgress()
return err
}
}
}
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
flag.PrintDefaults()
}
flag.Parse(args)
model := flag.Arg(0)
if model == "" {
return fmt.Errorf("missing model argument")
}
from := cmp.Or(*flagFrom, model)
m, err := ollama.ResolveLocal(c, from)
if err != nil {
return err
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
switch {
case errors.Is(err, ollama.ErrCached):
fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
case err != nil:
fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
case n == 0:
l := m.Layer(l.Digest)
mt, p, _ := mime.ParseMediaType(l.MediaType)
mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
switch mt {
case "tensor":
fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
default:
fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
}
}
},
})
return rc.Push(ctx, c, model, &ollama.PushParams{
From: from,
})
}
type trackingReader struct {
io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
r.n.Add(int64(n))
return n, err
}
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("import", flag.ExitOnError)
flagAs := flag.String("as", "", "Import using the provided name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
flag.PrintDefaults()
}
flag.Parse(args)
dir := cmp.Or(flag.Arg(0), ".")
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
m, err := safetensors.Read(os.DirFS(dir))
if err != nil {
return err
}
var total int64
var tt []*safetensors.Tensor
for t, err := range m.Tensors() {
if err != nil {
return err
}
tt = append(tt, t)
total += t.Size()
}
var n atomic.Int64
done := make(chan error)
go func() {
layers := make([]*ollama.Layer, len(tt))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
var ctxErr error
for i, t := range tt {
if ctx.Err() != nil {
// The context may cancel AFTER we exit the
// loop, and so if we use ctx.Err() after the
// loop we may report it as the error that
// broke the loop, when it was not. This can
// manifest as a false-negative, leading the
// user to think their import failed when it
// did not, so capture it if and only if we
// exit the loop because of a ctx.Err() and
// report it.
ctxErr = ctx.Err()
break
}
g.Go(func() (err error) {
rc, err := t.Reader()
if err != nil {
return err
}
defer rc.Close()
tr := &trackingReader{rc, &n}
d, err := c.Import(tr, t.Size())
if err != nil {
return err
}
if err := rc.Close(); err != nil {
return err
}
layers[i] = &ollama.Layer{
Digest: d,
Size: t.Size(),
MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
"name": t.Name(),
"dtype": t.DataType(),
"shape": t.Shape().String(),
}),
}
return nil
})
}
done <- func() error {
if err := errors.Join(g.Wait(), ctxErr); err != nil {
return err
}
m := &ollama.Manifest{Layers: layers}
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return err
}
d := blob.DigestFromBytes(data)
err = blob.PutBytes(c, d, data)
if err != nil {
return err
}
return c.Link(*flagAs, d)
}()
}()
fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
csiHideCursor(stdout)
defer csiShowCursor(stdout)
csiSavePos(stdout)
writeProgress := func() {
csiRestorePos(stdout)
nn := n.Load()
fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
formatNatural(nn),
formatNatural(total),
nn*100/total,
ansiClearToEndOfLine,
)
}
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
writeProgress()
case err := <-done:
writeProgress()
return err
}
}
}
func formatNatural(n int64) string {
switch {
case n < 1024:
return fmt.Sprintf("%d B", n)
case n < 1024*1024:
return fmt.Sprintf("%.1f KB", float64(n)/1024)
case n < 1024*1024*1024:
return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
default:
return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
}
}
const ansiClearToEndOfLine = "\033[K"
func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }

View File

@ -0,0 +1,11 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}

View File

@ -0,0 +1,107 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}
run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
return err
}
var sleepTime atomic.Int64
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()
// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))
b.SetBytes(fileSize)
b.ReportAllocs()
// Give our CDN a min to breathe between benchmarks.
time.Sleep(time.Duration(sleepTime.Swap(3)))
for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}
func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()
dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()
data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}

View File

@ -0,0 +1,48 @@
package backoff
import (
"context"
"iter"
"math/rand/v2"
"time"
)
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
var t *time.Timer
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
if !yield(n, nil) {
return
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := time.Duration(n*n) * 10 * time.Millisecond
if d > maxBackoff {
d = maxBackoff
}
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
}
}
}
}

View File

@ -0,0 +1,40 @@
//go:build goexperiment.synctest
package backoff
import (
"context"
"errors"
"testing"
"testing/synctest"
"time"
)
func TestLoop(t *testing.T) {
synctest.Run(func() {
last := -1
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
for n, err := range Loop(ctx, 100*time.Millisecond) {
if !errors.Is(err, ctx.Err()) {
t.Errorf("err = %v, want nil", err)
}
if err != nil {
break
}
if n != last+1 {
t.Errorf("n = %d, want %d", n, last+1)
}
last = n
if n > 5 {
cancel()
}
}
if last != 6 {
t.Errorf("last = %d, want 6", last)
}
})
}

View File

@ -0,0 +1,38 @@
package backoff
import (
"context"
"testing"
"testing/synctest"
"time"
)
func TestLoopAllocs(t *testing.T) {
for i := range 3 {
got := testing.AllocsPerRun(1000, func() {
for tick := range Loop(t.Context(), 1) {
if tick >= i {
break
}
}
})
want := float64(0)
if i > 0 {
want = 3 // due to time.NewTimer
}
if got > want {
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
}
}
}
func BenchmarkLoop(b *testing.B) {
ctx := context.Background()
synctest.Run(func() {
for n := range Loop(ctx, 100*time.Millisecond) {
if n == b.N {
break
}
}
})
}

View File

@ -0,0 +1,229 @@
package names
import (
"cmp"
"fmt"
"strings"
"github.com/ollama/ollama/server/internal/internal/stringsx"
)
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
type Name struct {
// Make incomparable to enfoce use of Compare / Equal for
// case-insensitive comparisons.
_ [0]func()
h string
n string
m string
t string
}
// Parse parses and assembles a Name from a name string. The
// format of a valid name string is:
//
// s:
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { model } ":" { tag }
// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
// { model } "@" { digest }
// { model }
// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
// length: [1, 80]
// model:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
//
// The name returned is not guaranteed to be valid. If it is not valid, the
// field values are left in an undefined state. Use [Name.IsValid] to check
// if the name is valid.
func Parse(s string) Name {
if len(s) > MaxNameLength {
return Name{}
}
var n Name
var tail string
var c byte
for {
s, tail, c = cutLastAny(s, "/:")
switch c {
case ':':
n.t = tail
continue // look for model
case '/':
n.h, n.n, _ = cutLastAny(s, "/")
n.m = tail
return n
case 0:
n.m = tail
return n
}
}
}
// ParseExtended parses and returns any scheme, Name, and digest from from s in
// the the form [scheme://][name][@digest]. All parts are optional.
//
// If the scheme is present, it must be followed by "://". The digest is
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
//
// The scheme and digest are stripped before the name is parsed by [Parse].
//
// For convience, the scheme is never empty. If the scheme is not present, the
// returned scheme is "https".
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
func ParseExtended(s string) (scheme string, _ Name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, Parse(s), digest
}
func FormatExtended(scheme string, n Name, digest string) string {
var b strings.Builder
if scheme != "" {
b.WriteString(scheme)
b.WriteString("://")
}
b.WriteString(n.String())
if digest != "" {
b.WriteByte('@')
b.WriteString(digest)
}
return b.String()
}
// Merge merges two names into a single name. Non-empty host, namespace, and
// tag parts of a take precedence over fields in b. The model field is left as
// is.
//
// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check
// if the name is valid.
func Merge(a, b Name) Name {
a.h = cmp.Or(a.h, b.h)
a.n = cmp.Or(a.n, b.n)
a.t = cmp.Or(a.t, b.t)
return a
}
// IsValid returns true if the name is valid.
func (n Name) IsValid() bool {
if n.h != "" && !isValidHost(n.h) {
return false
}
if n.n != "" && !isValidNamespace(n.n) {
return false
}
if n.m != "" && !isValidModel(n.m) {
return false
}
if n.t != "" && !isValidTag(n.t) {
return false
}
return true
}
func (n Name) IsFullyQualified() bool {
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
}
func isValidHost(_ string) bool {
return true // TODO: implement
}
func isValidNamespace(_ string) bool {
return true // TODO: implement
}
func isValidModel(_ string) bool {
return true // TODO: implement
}
func isValidTag(_ string) bool {
return true // TODO: implement
}
func (n Name) Host() string { return n.h }
func (n Name) Namespace() string { return n.n }
func (n Name) Model() string { return n.m }
func (n Name) Tag() string { return n.t }
// Compare compares n and o case-insensitively. It returns 0 if n and o are
// equal, -1 if n sorts before o, and 1 if n sorts after o.
func (n Name) Compare(o Name) int {
return cmp.Or(
stringsx.CompareFold(n.h, o.h),
stringsx.CompareFold(n.n, o.n),
stringsx.CompareFold(n.m, o.m),
stringsx.CompareFold(n.t, o.t),
)
}
// String returns the fully qualified name in the format
// <namespace>/<model>:<tag>.
func (n Name) String() string {
var b strings.Builder
if n.h != "" {
b.WriteString(n.h)
b.WriteByte('/')
}
if n.n != "" {
b.WriteString(n.n)
b.WriteByte('/')
}
b.WriteString(n.m)
if n.t != "" {
b.WriteByte(':')
b.WriteString(n.t)
}
return b.String()
}
func (n Name) GoString() string {
return fmt.Sprintf("<Name %q %q %q %q>", n.h, n.n, n.m, n.t)
}
// cutLastAny is like strings.Cut but scans in reverse for the last character
// in chars. If no character is found, before is the empty string and after is
// s. The returned sep is the byte value of the character in chars if one was
// found; otherwise it is 0.
func cutLastAny(s, chars string) (before, after string, sep byte) {
i := strings.LastIndexAny(s, chars)
if i >= 0 {
return s[:i], s[i+1:], s[i]
}
return "", s, 0
}

View File

@ -0,0 +1,152 @@
package names
import (
"strings"
"testing"
)
func TestParseName(t *testing.T) {
cases := []struct {
in string
want Name
}{
{"", Name{}},
{"m:t", Name{m: "m", t: "t"}},
{"m", Name{m: "m"}},
{"/m", Name{m: "m"}},
{"/n/m:t", Name{n: "n", m: "m", t: "t"}},
{"n/m", Name{n: "n", m: "m"}},
{"n/m:t", Name{n: "n", m: "m", t: "t"}},
{"n/m", Name{n: "n", m: "m"}},
{"n/m", Name{n: "n", m: "m"}},
{strings.Repeat("m", MaxNameLength+1), Name{}},
{"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}},
{"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}},
// Invalids
// TODO: {"n:t/m:t", Name{}},
// TODO: {"/h/n/m:t", Name{}},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got := Parse(tt.in)
if got.Compare(tt.want) != 0 {
t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want)
}
})
}
}
func TestString(t *testing.T) {
cases := []string{
"",
"m:t",
"m:t",
"m",
"n/m",
"n/m:t",
"n/m",
"n/m",
"h/n/m:t",
"ollama.com/library/_:latest",
// Special cased to "round trip" without the leading slash.
"/m",
"/n/m:t",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
s = strings.TrimPrefix(s, "/")
if g := Parse(s).String(); g != s {
t.Errorf("parse(%q).String() = %q", s, g)
}
})
}
}
func TestParseExtended(t *testing.T) {
cases := []struct {
in string
wantScheme string
wantName Name
wantDigest string
}{
{"", "", Name{}, ""},
{"m", "", Name{m: "m"}, ""},
{"http://m", "http", Name{m: "m"}, ""},
{"http+insecure://m", "http+insecure", Name{m: "m"}, ""},
{"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
scheme, name, digest := ParseExtended(tt.in)
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
}
// Round trip
if got := FormatExtended(scheme, name, digest); got != tt.in {
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
}
})
}
}
func TestMerge(t *testing.T) {
cases := []struct {
a, b string
want string
}{
{"", "", ""},
{"m", "", "m"},
{"", "m", ""},
{"x", "y", "x"},
{"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"},
{"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"},
{"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"},
{"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
a, b := Parse(tt.a), Parse(tt.b)
got := Merge(a, b)
if got.Compare(Parse(tt.want)) != 0 {
t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want)
}
})
}
}
func TestParseStringRoundTrip(t *testing.T) {
cases := []string{
"",
"m",
"m:t",
"n/m",
"n/m:t",
"n/m:t",
"n/m",
"n/m",
"h/n/m:t",
"ollama.com/library/_:latest",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
if got := Parse(s).String(); got != s {
t.Errorf("parse(%q).String() = %q", s, got)
}
})
}
}
var junkName Name
func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
junkName = Parse("h/n/m:t")
}
}

View File

@ -0,0 +1,52 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package stringsx provides additional string manipulation functions
// that aren't in the standard library's strings package or go4.org/mem.
package stringsx
import (
"unicode"
"unicode/utf8"
)
// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b,
// like cmp.Compare, but case insensitively.
func CompareFold(a, b string) int {
// Track our position in both strings
ia, ib := 0, 0
for ia < len(a) && ib < len(b) {
ra, wa := nextRuneLower(a[ia:])
rb, wb := nextRuneLower(b[ib:])
if ra < rb {
return -1
}
if ra > rb {
return 1
}
ia += wa
ib += wb
if wa == 0 || wb == 0 {
break
}
}
// If we've reached here, one or both strings are exhausted
// The shorter string is "less than" if they match up to this point
switch {
case ia == len(a) && ib == len(b):
return 0
case ia == len(a):
return -1
default:
return 1
}
}
// nextRuneLower returns the next rune in the string, lowercased, along with its
// original (consumed) width in bytes. If the string is empty, it returns
// (utf8.RuneError, 0)
func nextRuneLower(s string) (r rune, width int) {
r, width = utf8.DecodeRuneInString(s)
return unicode.ToLower(r), width
}

View File

@ -0,0 +1,78 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package stringsx
import (
"cmp"
"strings"
"testing"
)
func TestCompareFold(t *testing.T) {
tests := []struct {
a, b string
}{
// Basic ASCII cases
{"", ""},
{"a", "a"},
{"a", "A"},
{"A", "a"},
{"a", "b"},
{"b", "a"},
{"abc", "ABC"},
{"ABC", "abc"},
{"abc", "abd"},
{"abd", "abc"},
// Length differences
{"abc", "ab"},
{"ab", "abc"},
// Unicode cases
{"世界", "世界"},
{"Hello世界", "hello世界"},
{"世界Hello", "世界hello"},
{"世界", "世界x"},
{"世界x", "世界"},
// Special case folding examples
{"ß", "ss"}, // German sharp s
{"fi", "fi"}, // fi ligature
{"Σ", "σ"}, // Greek sigma
{"İ", "i\u0307"}, // Turkish dotted I
// Mixed cases
{"HelloWorld", "helloworld"},
{"HELLOWORLD", "helloworld"},
{"helloworld", "HELLOWORLD"},
{"HelloWorld", "helloworld"},
{"helloworld", "HelloWorld"},
// Edge cases
{" ", " "},
{"1", "1"},
{"123", "123"},
{"!@#", "!@#"},
}
wants := []int{}
for _, tt := range tests {
got := CompareFold(tt.a, tt.b)
want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b))
if got != want {
t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want)
}
wants = append(wants, want)
}
if n := testing.AllocsPerRun(1000, func() {
for i, tt := range tests {
if CompareFold(tt.a, tt.b) != wants[i] {
panic("unexpected")
}
}
}); n > 0 {
t.Errorf("allocs = %v; want 0", int(n))
}
}

View File

@ -0,0 +1,201 @@
// Package syncs provides synchronization primitives.
package syncs
import (
"cmp"
"io"
"sync"
)
var closedChan = func() chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}()
// Ticket represents a ticket in a sequence of tickets. The zero value is
// invalid. Use [Line.Take] to get a valid ticket.
//
// A Ticket is not safe for concurrent use.
type Ticket struct {
ahead chan struct{} // ticket ahead of this one
ch chan struct{}
}
// Ready returns a channel that is closed when the ticket before this one is
// done.
//
// It is incorrect to wait on Ready after the ticket is done.
func (t *Ticket) Ready() chan struct{} {
return cmp.Or(t.ahead, closedChan)
}
// Done signals that this ticket is done and that the next ticket in line can
// proceed.
//
// The first call to [Done] unblocks the ticket after it, if any. Subsequent
// calls are no-ops.
func (t *Ticket) Done() {
if t.ch != nil {
close(t.ch)
}
t.ch = nil
}
// Line is an ordered sequence of tickets waiting for their turn to proceed.
//
// To get a ticket use [Line.Take].
// To signal that a ticket is done use [Ticket.Done].
// To wait your turn use [Ticket.Ready].
//
// A Line is not safe for concurrent use.
type Line struct {
last chan struct{} // last ticket in line
}
func (q *Line) Take() *Ticket {
t := &Ticket{
ahead: q.last,
ch: make(chan struct{}),
}
q.last = t.ch
return t
}
// RelayReader implements an [io.WriterTo] that yields the passed
// writer to its [WriteTo] method each [io.WriteCloser] taken from [Take], in
// the order they are taken. Each [io.WriteCloser] blocks until the previous
// one is closed, or a call to [RelayReader.CloseWithError] is made.
//
// The zero value is invalid. Use [NewWriteToLine] to get a valid RelayReader.
//
// It is not safe for concurrent use.
type RelayReader struct {
line Line
t *Ticket
w io.Writer
n int64
mu sync.Mutex
err error // set by CloseWithError
closedCh chan struct{} // closed if err is set
}
var (
_ io.Closer = (*RelayReader)(nil)
_ io.WriterTo = (*RelayReader)(nil)
_ io.Reader = (*RelayReader)(nil)
)
func NewRelayReader() *RelayReader {
var q RelayReader
q.closedCh = make(chan struct{})
q.t = q.line.Take()
return &q
}
// CloseWithError terminates the line, unblocking any writer waiting for its
// turn with the error, or [io.EOF] if err is nil. It is safe to call
// [CloseWithError] multiple times and across multiple goroutines.
//
// If the line is already closed, [CloseWithError] is a no-op.
//
// It never returns an error.
func (q *RelayReader) CloseWithError(err error) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.err == nil {
q.err = cmp.Or(q.err, err, io.EOF)
close(q.closedCh)
}
return nil
}
// Close closes the line. Any writer waiting for its turn will be unblocked
// with an [io.ErrClosedPipe] error.
//
// It never returns an error.
func (q *RelayReader) Close() error {
return q.CloseWithError(nil)
}
func (q *RelayReader) closed() <-chan struct{} {
q.mu.Lock()
defer q.mu.Unlock()
return q.closedCh
}
func (q *RelayReader) Read(p []byte) (int, error) {
panic("RelayReader.Read is for show only; use WriteTo")
}
// WriteTo yields the writer w to the first writer in line and blocks until the
// first call to [Close].
//
// It is safe to call [Take] concurrently with [WriteTo].
func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
select {
case <-q.closed():
return 0, io.ErrClosedPipe
default:
}
// We have a destination writer; let the relay begin.
q.w = dst
q.t.Done()
<-q.closed()
return q.n, nil
}
// Take returns a writer that will be passed to the next writer in line.
//
// It is not safe for use across multiple goroutines.
func (q *RelayReader) Take() io.WriteCloser {
return &relayWriter{q: q, t: q.line.Take()}
}
type relayWriter struct {
q *RelayReader
t *Ticket
ready bool
}
var _ io.StringWriter = (*relayWriter)(nil)
// Write writes to the writer passed to [RelayReader.WriteTo] as soon as the
// writer is ready. It returns io.ErrClosedPipe if the line is closed before
// the writer is ready.
func (w *relayWriter) Write(p []byte) (int, error) {
if !w.awaitTurn() {
return 0, w.q.err
}
n, err := w.q.w.Write(p)
w.q.n += int64(n)
return n, err
}
func (w *relayWriter) WriteString(s string) (int, error) {
if !w.awaitTurn() {
return 0, w.q.err
}
return io.WriteString(w.q.w, s)
}
// Close signals that the writer is done, unblocking the next writer in line.
func (w *relayWriter) Close() error {
w.t.Done()
return nil
}
func (t *relayWriter) awaitTurn() (ok bool) {
if t.ready {
return true
}
select {
case <-t.t.Ready():
t.ready = true
return true
case <-t.q.closed():
return false
}
}

View File

@ -0,0 +1,65 @@
package syncs
import (
"bytes"
"io"
"math/rand/v2"
"testing"
"testing/synctest"
)
func TestPipelineReadWriterTo(t *testing.T) {
for range 10 {
synctest.Run(func() {
q := NewRelayReader()
tickets := []struct {
io.WriteCloser
s string
}{
{q.Take(), "you"},
{q.Take(), " say hi,"},
{q.Take(), " and "},
{q.Take(), "I say "},
{q.Take(), "hello"},
}
rand.Shuffle(len(tickets), func(i, j int) {
tickets[i], tickets[j] = tickets[j], tickets[i]
})
var g Group
for i, t := range tickets {
g.Go(func() {
defer t.Close()
if i%2 == 0 {
// Use [relayWriter.WriteString]
io.WriteString(t.WriteCloser, t.s)
} else {
t.Write([]byte(t.s))
}
})
}
var got bytes.Buffer
var copyErr error // checked at end
g.Go(func() {
_, copyErr = io.Copy(&got, q)
})
synctest.Wait()
q.Close()
g.Wait()
if copyErr != nil {
t.Fatal(copyErr)
}
want := "you say hi, and I say hello"
if got.String() != want {
t.Fatalf("got %q, want %q", got.String(), want)
}
})
}
}

View File

@ -0,0 +1,41 @@
package syncs
import (
"sync"
"sync/atomic"
)
// Group is a [sync.WaitGroup] with a Go method.
type Group struct {
wg sync.WaitGroup
n atomic.Int64
}
func (g *Group) Go(f func()) {
g.wg.Add(1)
go func() {
g.n.Add(1) // Now we are running
defer func() {
g.wg.Done()
g.n.Add(-1) // Now we are done
}()
f()
}()
}
// Running returns the number of goroutines that are currently running.
//
// If a call to [Running] returns zero, and a call to [Wait] is made without
// any calls to [Go], then [Wait] will return immediately. This is true even if
// a goroutine is started and finishes between the two calls.
//
// It is possible for [Running] to return non-zero and for [Wait] to return
// immediately. This can happen if the all running goroutines finish between
// the two calls.
func (g *Group) Running() int64 {
return g.n.Load()
}
func (g *Group) Wait() {
g.wg.Wait()
}

View File

@ -0,0 +1,74 @@
package testutil
import (
"os"
"path/filepath"
"testing"
"time"
)
// Check calls t.Fatal(err) if err is not nil.
func Check(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Fatal(err)
}
}
// CheckFunc exists so other packages do not need to invent their own type for
// taking a Check function.
type CheckFunc func(err error)
// Checker returns a check function that
// calls t.Fatal if err is not nil.
func Checker(t *testing.T) (check func(err error)) {
return func(err error) {
if err != nil {
t.Helper()
t.Fatal(err)
}
}
}
// StopPanic runs f but silently recovers from any panic f causes.
// The normal usage is:
//
// testutil.StopPanic(func() {
// callThatShouldPanic()
// t.Errorf("callThatShouldPanic did not panic")
// })
func StopPanic(f func()) {
defer func() { recover() }()
f()
}
// CheckTime calls t.Fatalf if got != want. Included in the error message is
// want.Sub(got) to help diagnose the difference, along with their values in
// UTC.
func CheckTime(t *testing.T, got, want time.Time) {
t.Helper()
if !got.Equal(want) {
t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
}
}
// WriteFile writes data to a file named name. It makes the directory if it
// doesn't exist and sets the file mode to perm.
//
// The name must be a relative path and must not contain .. or start with a /;
// otherwise WriteFile will panic.
func WriteFile[S []byte | string](t testing.TB, name string, data S) {
t.Helper()
if filepath.IsAbs(name) {
t.Fatalf("WriteFile: name must be a relative path, got %q", name)
}
name = filepath.Clean(name)
dir := filepath.Dir(name)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
t.Fatal(err)
}
}

View File

@ -0,0 +1,118 @@
// Package manifest provides documentation for the Ollama manifest format.
// This package contains no code.
//
// # Manifests
//
// A manifest is a JSON object that describes a model. The JSON object has a
// single field "layers" which is a list of layers that make up the model. Each
// layer has the following fields:
//
// A layer is a single, logical unit of a model. Layers are stored in the cache
// as files with the name of the digest of the layer. Layers are pushed and
// pulled from the registry as blobs.
//
// A layer is represented as a JSON object with the following fields:
//
// - "digest": The digest of the layer.
// - "mediaType": The media type of the layer.
// - "size": The size of the layer in bytes.
//
// Layers are typically stored in a blob store, such as a registry, and are
// referenced by their digest. This package does not define how layers are
// stored or retrieved.
//
// # Configuration Layer
//
// The configuration of a model is represented as a layer with the media type:
//
// application/vnd.ollama.image.config; type=<type>
//
// The "type" parameter in the media type specifies the format of the
// configuration (e.g., "safetensor" or "gguf").
//
// There may be only one configuration layer in a model.
//
// # Template Layer
//
// The model template is a layer with the media type:
//
// application/vnd.ollama.image.template; [name=<name>]
//
// The "name" parameter in the media type specifies the name of the template as
// for lookup at runtime. The name is optional and may be omitted. If omitted,
// the template is the default template for the model.
//
// # Tensor Layers
//
// The tensors of a model are represented as layers with the media type:
//
// application/vnd.ollama.image.tensor; name=<name>; dtype=<dtype>; shape=<shape>
//
// The "name" parameter in the media type specifies the name of the tensor as
// defined in the model's configuration and are bound only by the rules for
// names as defined in the configuration format, as represented by the
// configuration's "type".
//
// The "dtype" parameter in the media type specifies the data type of the tensor
// as a string.
//
// TODO: Define more specifically how to represent data types as strings.
//
// The "shape" parameter in the media type specifies the shape of the tensor as
// a comma-separated list of integers; one per dimension.
//
// # Tokenization Layers
//
// The tokenization of a model is represented as a layer with the media type:
//
// application/vnd.ollama.image.tokenizer
//
// The configuration of the tokenizer is represented as a layer with the media type:
//
// application/vnd.ollama.image.tokenizer.config
//
// # Miscellaneous Layers
//
// These extra layer mime types are reserved:
//
// application/vnd.ollama.image.license
//
// This layer contains one of the many licenses for the model in plain text.
//
// # Example Manifest
//
// The following is an example manifest containing a configuration, a model
// template, and two tensors (digests shortened for brevity):
//
// {
// "layers": [{
// "digest": "sha256:a...",
// "mediaType": "application/vnd.ollama.image.config; type=safetensors",
// "size": 1234
// },{
// "digest": "sha256:b...",
// "mediaType": "application/vnd.ollama.image.template",
// "size": 5678
// },{
// "digest": "sha256:c...",
// "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3",
// "size": 9012
// },{
// "digest": "sha256:d...",
// "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6",
// "size": 3456
// }]
// }
//
// # Legacy Media Types
//
// The appliaction/vnd.ollama.image.model media type is deprecated, but will
// remain supported for backwards compatibility, for some undefined amount of
// time. New models should use the media types defined above.
//
// # Reserved media types
//
// The media type prefix "application/vnd.ollama.image." is reserved for
// defining new media types for layers known to Ollama. Currently, all other
// prefixes are ignored by official Ollama registry clients.
package manifest