diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 8af8812fb..56a2cc4fd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -147,6 +147,7 @@ jobs: runs-on: ${{ matrix.os }} env: CGO_ENABLED: '1' + GOEXPERIMENT: 'synctest' steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 diff --git a/.gitignore b/.gitignore index 551abec87..3a2af0bd1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ .swp dist build -ollama .cache *.exe .idea @@ -14,3 +13,4 @@ test_data __debug_bin* llama/build llama/vendor +/ollama diff --git a/.golangci.yaml b/.golangci.yaml index 9d59fd6c0..9bb9786a8 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -6,8 +6,6 @@ linters: - bidichk - bodyclose - containedctx - - contextcheck - - errcheck - gocheckcompilerdirectives - gofmt - gofumpt @@ -23,10 +21,11 @@ linters: - staticcheck - tenv - unconvert - - unused - - usestdlibvars - wastedassign - whitespace + disable: + - usestdlibvars + - errcheck linters-settings: staticcheck: checks: @@ -39,5 +38,4 @@ severity: - gofmt - goimports - intrange - - usestdlibvars severity: info diff --git a/server/internal/cache/blob/cache.go b/server/internal/cache/blob/cache.go new file mode 100644 index 000000000..f0b0760f1 --- /dev/null +++ b/server/internal/cache/blob/cache.go @@ -0,0 +1,544 @@ +// Package blob implements a content-addressable disk cache for blobs and +// manifests. +package blob + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "hash" + "io" + "io/fs" + "iter" + "os" + "path/filepath" + "strings" + "time" + + "github.com/ollama/ollama/server/internal/internal/names" +) + +// Entry contains metadata about a blob in the cache. +type Entry struct { + Digest Digest + Size int64 + Time time.Time // when added to the cache +} + +// DiskCache caches blobs and manifests on disk. +// +// The cache is rooted at a directory, which is created if it does not exist. +// +// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the +// "manifests" subdirectory. An example directory structure might look like: +// +// <dir>/ +// blobs/ +// sha256-<digest> - <blob data> +// manifests/ +// <host>/ +// <namespace>/ +// <model>/ +// <tag> - <manifest data> +// +// The cache is safe for concurrent use. +// +// Name casing is preserved in the cache, but is not significant when resolving +// names. For example, "Foo" and "foo" are considered the same name. +// +// The cache guards concurrent writes, but does not prevent duplicated effort. +// Because blobs are immutable, duplicate writes should result in the same file +// being written to disk. +type DiskCache struct { + // Dir specifies the top-level directory where blobs and manifest + // pointers are stored. + dir string + now func() time.Time + + testHookBeforeFinalWrite func(f *os.File) +} + +// PutBytes is a convenience function for c.Put(d, bytes.NewReader([]byte(data)), int64(len(data))). +func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error { + return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data))) +} + +// Open opens a cache rooted at the given directory. If the directory does not +// exist, it is created. If the path exists but is not a directory, an error is +// returned.
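+//
+// A minimal usage sketch (the directory path below is illustrative only, not
+// a path this package requires):
+//
+//	c, err := blob.Open("/var/lib/ollama/models")
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = c // use the cache via Put, Get, Link, Resolve, ...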
+func Open(dir string) (*DiskCache, error) { + if dir == "" { + return nil, errors.New("blob: empty directory name") + } + + info, err := os.Stat(dir) + if err == nil && !info.IsDir() { + return nil, fmt.Errorf("%q is not a directory", dir) + } + if err := os.MkdirAll(dir, 0o777); err != nil { + return nil, err + } + + subdirs := []string{"blobs", "manifests"} + for _, subdir := range subdirs { + if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil { + return nil, err + } + } + + // TODO(bmizerany): support shards + c := &DiskCache{ + dir: dir, + now: time.Now, + } + return c, nil +} + +func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) { + f, err := os.Open(filename) + if err != nil { + return nil, Digest{}, err + } + defer f.Close() + + h := sha256.New() + r := io.TeeReader(f, h) + data, err = io.ReadAll(io.LimitReader(r, limit)) + if err != nil { + return nil, Digest{}, err + } + var d Digest + h.Sum(d.sum[:0]) + return data, d, nil +} + +//lint:ignore U1000 used for debugging purposes as needed in tests +var debug = false + +// debugger returns a function that can be used to add a step to the error message. +// The error message will be a list of steps that were taken before the error occurred. +// The steps are added in the order they are called. +// +// To set the error message, call the returned function with an empty string. +// +//lint:ignore U1000 used for debugging purposes as needed in tests +func debugger(err *error) func(step string) { + if !debug { + return func(string) {} + } + var steps []string + return func(step string) { + if step == "" && *err != nil { + *err = fmt.Errorf("%q: %w", steps, *err) + return + } + steps = append(steps, step) + if len(steps) > 100 { + // shift hints in case of a bug that causes a lot of hints + copy(steps, steps[1:]) + steps = steps[:100] + } + } +} + +// Resolve resolves a name to a digest. The name is expected to +// be in either of the following forms: +// +// <name>@<digest> +// <name> +// <digest> +// +// If a digest is provided, it is returned as is and nothing else happens. +// +// If a name is provided for a manifest that exists in the cache, the digest +// of the manifest is returned. If there is no manifest in the cache, it +// returns [fs.ErrNotExist]. +// +// To cover the case where a manifest may change without the cache knowing +// (e.g. it was reformatted or modified by hand), the manifest data read and +// hashed is passed to a PutBytes call to ensure that the manifest is in the +// blob store. This is done to ensure that future calls to [Get] succeed in +// these cases. +// +// TODO(bmizerany): Move Links/Resolve/etc. out of this package. +func (c *DiskCache) Resolve(name string) (Digest, error) { + name, digest := splitNameDigest(name) + if digest != "" { + return ParseDigest(digest) + } + + // We want to address manifest files by digest using Get. That requires + // them to be blobs. This cannot be directly accomplished by looking in + // the blob store because manifests can change without Ollama knowing + // (e.g. a user modifies a manifest by hand then pushes it to update + // their model). We also need to support the blob caches inherited from + // older versions of Ollama, which do not store manifests in the blob + // store, so for these cases, we need to handle adding the manifests to + // the blob store, just in time. + // + // So now we read the manifest file, hash it, and copy it to the blob + // store if it's not already there.
+ // + // This should be cheap because manifests are small, and accessed + // infrequently. + file, err := c.manifestPath(name) + if err != nil { + return Digest{}, err + } + + data, d, err := readAndSum(file, 1<<20) + if err != nil { + return Digest{}, err + } + + // Ideally we'd read the "manifest" file as a manifest to the blob file, + // but we are not changing this yet, so copy the manifest to the blob + // store so it can be addressed by digest subsequent calls to Get. + if err := PutBytes(c, d, data); err != nil { + return Digest{}, err + } + return d, nil +} + +// Put writes a new blob to the cache, identified by its digest. The operation +// reads content from r, which must precisely match both the specified size and +// digest. +// +// Concurrent write safety is achieved through file locking. The implementation +// guarantees write integrity by enforcing size limits and content validation +// before allowing the file to reach its final state. +func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error { + return c.copyNamedFile(c.GetFile(d), r, d, size) +} + +// Import imports a blob from the provided reader into the cache. It reads the +// entire content of the reader, calculates its digest, and stores it in the +// cache. +// +// Import should be considered unsafe for use with untrusted data, such as data +// read from a network. The caller is responsible for ensuring the integrity of +// the data being imported. +func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) { + // users that want to change the temp dir can set TEMPDIR. + f, err := os.CreateTemp("", "blob-") + if err != nil { + return Digest{}, err + } + defer os.Remove(f.Name()) + + // Copy the blob to a temporary file. + h := sha256.New() + r = io.TeeReader(r, h) + n, err := io.Copy(f, r) + if err != nil { + return Digest{}, err + } + if n != size { + return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n) + } + + // Check the digest. + var d Digest + h.Sum(d.sum[:0]) + if err := f.Close(); err != nil { + return Digest{}, err + } + name := c.GetFile(d) + // Rename the temporary file to the final file. + if err := os.Rename(f.Name(), name); err != nil { + return Digest{}, err + } + os.Chtimes(name, c.now(), c.now()) // mainly for tests + return d, nil +} + +// Get retrieves a blob from the cache using the provided digest. The operation +// fails if the digest is malformed or if any errors occur during blob +// retrieval. +func (c *DiskCache) Get(d Digest) (Entry, error) { + name := c.GetFile(d) + info, err := os.Stat(name) + if err != nil { + return Entry{}, err + } + if info.Size() == 0 { + return Entry{}, fs.ErrNotExist + } + return Entry{ + Digest: d, + Size: info.Size(), + Time: info.ModTime(), + }, nil +} + +// Link creates a symbolic reference in the cache that maps the provided name +// to a blob identified by its digest, making it retrievable by name using +// [Resolve]. +// +// It returns an error if either the name or digest is invalid, or if link +// creation encounters any issues. 
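+//
+// A hedged sketch of typical use (the name and blob contents below are
+// illustrative, not values this package defines):
+//
+//	data := []byte("{...manifest JSON...}")
+//	d := blob.DigestFromBytes(data)
+//	if err := blob.PutBytes(c, d, data); err != nil {
+//		// handle error
+//	}
+//	if err := c.Link("registry.example.com/library/mymodel:latest", d); err != nil {
+//		// handle error
+//	}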
+func (c *DiskCache) Link(name string, d Digest) error { + manifest, err := c.manifestPath(name) + if err != nil { + return err + } + f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0) + if err != nil { + return err + } + defer f.Close() + + // TODO(bmizerany): test this happens only if the blob was found to + // avoid leaving debris + if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil { + return err + } + + info, err := f.Stat() + if err != nil { + return err + } + + // Copy manifest to cache directory. + return c.copyNamedFile(manifest, f, d, info.Size()) +} + +// Unlink removes the any link for name. If the link does not exist, nothing +// happens, and no error is returned. +// +// It returns an error if the name is invalid or if the link removal encounters +// any issues. +func (c *DiskCache) Unlink(name string) error { + manifest, err := c.manifestPath(name) + if err != nil { + return err + } + err = os.Remove(manifest) + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err +} + +// GetFile returns the absolute path to the file, in the cache, for the given +// digest. It does not check if the file exists. +// +// The returned path should not be stored, used outside the lifetime of the +// cache, or interpreted in any way. +func (c *DiskCache) GetFile(d Digest) string { + filename := fmt.Sprintf("sha256-%x", d.sum) + return absJoin(c.dir, "blobs", filename) +} + +// Links returns a sequence of links in the cache in lexical order. +func (c *DiskCache) Links() iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + for path, err := range c.links() { + if err != nil { + yield("", err) + return + } + if !yield(pathToName(path), nil) { + return + } + } + } +} + +// pathToName converts a path to a name. It is the inverse of nameToPath. The +// path is assumed to be in filepath.ToSlash format. +func pathToName(s string) string { + s = strings.TrimPrefix(s, "manifests/") + rr := []rune(s) + for i := len(rr) - 1; i > 0; i-- { + if rr[i] == '/' { + rr[i] = ':' + return string(rr) + } + } + return s +} + +// manifestPath finds the first manifest file on disk that matches the given +// name using a case-insensitive comparison. If no manifest file is found, it +// returns the path where the manifest file would be if it existed. +// +// If two manifest files exists on disk that match the given name using a +// case-insensitive comparison, the one that sorts first, lexically, is +// returned. +func (c *DiskCache) manifestPath(name string) (string, error) { + np, err := nameToPath(name) + if err != nil { + return "", err + } + + maybe := filepath.Join("manifests", np) + for l, err := range c.links() { + if err != nil { + return "", err + } + if strings.EqualFold(maybe, l) { + return filepath.Join(c.dir, l), nil + } + } + return filepath.Join(c.dir, maybe), nil +} + +// links returns a sequence of links in the cache in lexical order. 
+func (c *DiskCache) links() iter.Seq2[string, error] { + // TODO(bmizerany): reuse empty dirnames if exist + return func(yield func(string, error) bool) { + fsys := os.DirFS(c.dir) + manifests, err := fs.Glob(fsys, "manifests/*/*/*/*") + if err != nil { + yield("", err) + return + } + for _, manifest := range manifests { + if !yield(manifest, nil) { + return + } + } + } +} + +type checkWriter struct { + d Digest + size int64 + n int64 + h hash.Hash + f *os.File + err error + + testHookBeforeFinalWrite func(*os.File) +} + +func (w *checkWriter) seterr(err error) error { + if w.err == nil { + w.err = err + } + return err +} + +// Write writes p to the underlying hash and writer. The last write to the +// underlying writer is guaranteed to be the last byte of p as verified by the +// hash. +func (w *checkWriter) Write(p []byte) (int, error) { + _, err := w.h.Write(p) + if err != nil { + return 0, w.seterr(err) + } + nextSize := w.n + int64(len(p)) + if nextSize == w.size { + // last write. check hash. + sum := w.h.Sum(nil) + if !bytes.Equal(sum, w.d.sum[:]) { + return 0, w.seterr(fmt.Errorf("file content changed underfoot")) + } + if w.testHookBeforeFinalWrite != nil { + w.testHookBeforeFinalWrite(w.f) + } + } + if nextSize > w.size { + return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size)) + } + n, err := w.f.Write(p) + w.n += int64(n) + return n, w.seterr(err) +} + +// copyNamedFile copies file into name, expecting it to have the given Digest +// and size, if that file is not present already. +func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error { + info, err := os.Stat(name) + if err == nil && info.Size() == size { + // File already exists with correct size. This is good enough. + // We can skip expensive hash checks. + // + // TODO: Do the hash check, but give caller a way to skip it. + return nil + } + + // Copy file to cache directory. + mode := os.O_RDWR | os.O_CREATE + if err == nil && info.Size() > size { // shouldn't happen but fix in case + mode |= os.O_TRUNC + } + f, err := os.OpenFile(name, mode, 0o666) + if err != nil { + return err + } + defer f.Close() + if size == 0 { + // File now exists with correct size. + // Only one possible zero-length file, so contents are OK too. + // Early return here makes sure there's a "last byte" for code below. + return nil + } + + // From here on, if any of the I/O writing the file fails, + // we make a best-effort attempt to truncate the file f + // before returning, to avoid leaving bad bytes in the file. + + // Copy file to f, but also into h to double-check hash. + cw := &checkWriter{ + d: out, + size: size, + h: sha256.New(), + f: f, + testHookBeforeFinalWrite: c.testHookBeforeFinalWrite, + } + n, err := io.Copy(cw, file) + if err != nil { + f.Truncate(0) + return err + } + if n < size { + f.Truncate(0) + return io.ErrUnexpectedEOF + } + + if err := f.Close(); err != nil { + // Data might not have been written, + // but file may look like it is the right size. + // To be extra careful, remove cached file. + os.Remove(name) + return err + } + os.Chtimes(name, c.now(), c.now()) // mainly for tests + + return nil +} + +func splitNameDigest(s string) (name, digest string) { + i := strings.LastIndexByte(s, '@') + if i < 0 { + return s, "" + } + return s[:i], s[i+1:] +} + +var errInvalidName = errors.New("invalid name") + +func nameToPath(name string) (_ string, err error) { + if strings.Contains(name, "@") { + // TODO(bmizerany): HACK: Fix names.Parse to validate. 
+ // TODO(bmizerany): merge with default parts (maybe names.Merge(a, b)) + return "", errInvalidName + } + n := names.Parse(name) + if !n.IsFullyQualified() { + return "", errInvalidName + } + return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil +} + +func absJoin(pp ...string) string { + abs, err := filepath.Abs(filepath.Join(pp...)) + if err != nil { + // Likely a bug bug or a bad OS problem. Just panic. + panic(err) + } + return abs +} diff --git a/server/internal/cache/blob/cache_test.go b/server/internal/cache/blob/cache_test.go new file mode 100644 index 000000000..704542ea3 --- /dev/null +++ b/server/internal/cache/blob/cache_test.go @@ -0,0 +1,685 @@ +package blob + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "slices" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/server/internal/internal/testutil" +) + +func init() { + debug = true +} + +var epoch = func() time.Time { + d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + if d.IsZero() { + panic("time zero") + } + return d +}() + +func TestOpenErrors(t *testing.T) { + exe, err := os.Executable() + if err != nil { + panic(err) + } + + cases := []struct { + dir string + err string + }{ + {t.TempDir(), ""}, + {"", "empty directory name"}, + {exe, "not a directory"}, + } + + for _, tt := range cases { + t.Run(tt.dir, func(t *testing.T) { + _, err := Open(tt.dir) + if tt.err == "" { + if err != nil { + t.Fatal(err) + } + return + } + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), tt.err) { + t.Fatalf("err = %v, want %q", err, tt.err) + } + }) + } +} + +func TestGetFile(t *testing.T) { + t.Chdir(t.TempDir()) + + c, err := Open(".") + if err != nil { + t.Fatal(err) + } + + d := mkdigest("1") + got := c.GetFile(d) + cleaned := filepath.Clean(got) + if cleaned != got { + t.Fatalf("got is unclean: %q", got) + } + if !filepath.IsAbs(got) { + t.Fatal("got is not absolute") + } + abs, _ := filepath.Abs(c.dir) + if !strings.HasPrefix(got, abs) { + t.Fatalf("got is not local to %q", c.dir) + } +} + +func TestBasic(t *testing.T) { + c, err := Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + now := epoch + c.now = func() time.Time { return now } + + checkEntry := entryChecker(t, c) + checkFailed := func(err error) { + if err == nil { + t.Helper() + t.Fatal("expected error") + } + } + + _, err = c.Resolve("invalid") + checkFailed(err) + + _, err = c.Resolve("h/n/m:t") + checkFailed(err) + + dx := mkdigest("x") + + d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx)) + if err != nil { + t.Fatal(err) + } + if d != dx { + t.Fatalf("d = %v, want %v", d, dx) + } + + _, err = c.Get(Digest{}) + checkFailed(err) + + // not committed yet + _, err = c.Get(dx) + checkFailed(err) + + err = PutBytes(c, dx, "!") + checkFailed(err) + + err = PutBytes(c, dx, "x") + if err != nil { + t.Fatal(err) + } + checkEntry(dx, 1, now) + + t0 := now + now = now.Add(1*time.Hour + 1*time.Minute) + err = PutBytes(c, dx, "x") + if err != nil { + t.Fatal(err) + } + + // check not updated + checkEntry(dx, 1, t0) +} + +type sleepFunc func(d time.Duration) time.Time + +func openTester(t *testing.T) (*DiskCache, sleepFunc) { + t.Helper() + c, err := Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + now := epoch + c.now = func() time.Time { return now } + return c, func(d time.Duration) time.Time { + now = now.Add(d) + return now + } +} + +func TestManifestPath(t *testing.T) { + check := testutil.Checker(t) + + c, sleep := openTester(t) + + d1 := 
mkdigest("1") + err := PutBytes(c, d1, "1") + check(err) + + err = c.Link("h/n/m:t", d1) + check(err) + + t0 := sleep(0) + sleep(1 * time.Hour) + err = c.Link("h/n/m:t", d1) // nop expected + check(err) + + file := must(c.manifestPath("h/n/m:t")) + info, err := os.Stat(file) + check(err) + testutil.CheckTime(t, info.ModTime(), t0) +} + +func TestManifestExistsWithoutBlob(t *testing.T) { + t.Chdir(t.TempDir()) + + check := testutil.Checker(t) + + c, err := Open(".") + check(err) + + checkEntry := entryChecker(t, c) + + man := must(c.manifestPath("h/n/m:t")) + os.MkdirAll(filepath.Dir(man), 0o777) + testutil.WriteFile(t, man, "1") + + got, err := c.Resolve("h/n/m:t") + check(err) + + want := mkdigest("1") + if got != want { + t.Fatalf("got = %v, want %v", got, want) + } + + e, err := c.Get(got) + check(err) + checkEntry(got, 1, e.Time) +} + +func TestPut(t *testing.T) { + c, sleep := openTester(t) + + check := testutil.Checker(t) + checkEntry := entryChecker(t, c) + + d := mkdigest("hello, world") + err := PutBytes(c, d, "hello") + if err == nil { + t.Fatal("expected error") + } + + got, err := c.Get(d) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("expected error, got %v", got) + } + + // Put a valid blob + err = PutBytes(c, d, "hello, world") + check(err) + checkEntry(d, 12, sleep(0)) + + // Put a blob with content that does not hash to the digest + err = PutBytes(c, d, "hello") + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + // Put the valid blob back and check it + err = PutBytes(c, d, "hello, world") + check(err) + checkEntry(d, 12, sleep(0)) + + // Put a blob that errors during Read + err = c.Put(d, &errOnBangReader{s: "!"}, 1) + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + // Put valid blob back and check it + err = PutBytes(c, d, "hello, world") + check(err) + checkEntry(d, 12, sleep(0)) + + // Put a blob with mismatched size + err = c.Put(d, strings.NewReader("hello, world"), 11) + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + // Final byte does not match the digest (testing commit phase) + err = PutBytes(c, d, "hello, world$") + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + reset := c.setTestHookBeforeFinalWrite(func(f *os.File) { + // change mode to read-only + f.Truncate(0) + f.Chmod(0o400) + f.Close() + f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { f1.Close() }) + *f = *f1 + }) + defer reset() + + err = PutBytes(c, d, "hello, world") + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + reset() +} + +func TestImport(t *testing.T) { + c, _ := openTester(t) + + checkEntry := entryChecker(t, c) + + want := mkdigest("x") + got, err := c.Import(strings.NewReader("x"), 1) + if err != nil { + t.Fatal(err) + } + if want != got { + t.Fatalf("digest = %v, want %v", got, want) + } + checkEntry(want, 1, epoch) + + got, err = c.Import(strings.NewReader("x"), 1) + if err != nil { + t.Fatal(err) + } + if want != got { + t.Fatalf("digest = %v, want %v", got, want) + } + checkEntry(want, 1, epoch) +} + +func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) { + old := c.testHookBeforeFinalWrite + c.testHookBeforeFinalWrite = h + return func() { c.testHookBeforeFinalWrite = old } +} + +func TestPutGetZero(t *testing.T) { + c, sleep := openTester(t) + + check := testutil.Checker(t) + checkEntry := entryChecker(t, c) + + d := mkdigest("x") + err := 
PutBytes(c, d, "x") + check(err) + checkEntry(d, 1, sleep(0)) + + err = os.Truncate(c.GetFile(d), 0) + check(err) + + _, err = c.Get(d) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want fs.ErrNotExist", err) + } +} + +func TestPutZero(t *testing.T) { + c, _ := openTester(t) + d := mkdigest("x") + err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content) + testutil.Check(t, err) + checkNotExists(t, c, d) +} + +func TestCommit(t *testing.T) { + check := testutil.Checker(t) + + c, err := Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + checkEntry := entryChecker(t, c) + + now := epoch + c.now = func() time.Time { return now } + + d1 := mkdigest("1") + err = c.Link("h/n/m:t", d1) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want fs.ErrNotExist", err) + } + + err = PutBytes(c, d1, "1") + check(err) + + err = c.Link("h/n/m:t", d1) + check(err) + + got, err := c.Resolve("h/n/m:t") + check(err) + if got != d1 { + t.Fatalf("d = %v, want %v", got, d1) + } + + // commit again, more than 1 byte + d2 := mkdigest("22") + err = PutBytes(c, d2, "22") + check(err) + err = c.Link("h/n/m:t", d2) + check(err) + checkEntry(d2, 2, now) + + filename := must(c.manifestPath("h/n/m:t")) + data, err := os.ReadFile(filename) + check(err) + if string(data) != "22" { + t.Fatalf("data = %q, want %q", data, "22") + } + + t0 := now + now = now.Add(1 * time.Hour) + err = c.Link("h/n/m:t", d2) // same contents; nop + check(err) + info, err := os.Stat(filename) + check(err) + testutil.CheckTime(t, info.ModTime(), t0) +} + +func TestManifestInvalidBlob(t *testing.T) { + c, _ := openTester(t) + d := mkdigest("1") + err := c.Link("h/n/m:t", d) + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + err = PutBytes(c, d, "1") + testutil.Check(t, err) + err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666) + if err != nil { + t.Fatal(err) + } + + err = c.Link("h/n/m:t", d) + if !strings.Contains(err.Error(), "underfoot") { + t.Fatalf("err = %v, want error to contain %q", err, "underfoot") + } +} + +func TestManifestNameReuse(t *testing.T) { + t.Run("case-insensitive", func(t *testing.T) { + // This should run on all file system types. 
+ testManifestNameReuse(t) + }) + t.Run("case-sensitive", func(t *testing.T) { + useCaseInsensitiveTempDir(t) + testManifestNameReuse(t) + }) +} + +func testManifestNameReuse(t *testing.T) { + check := testutil.Checker(t) + + c, _ := openTester(t) + + d1 := mkdigest("1") + err := PutBytes(c, d1, "1") + check(err) + err = c.Link("h/n/m:t", d1) + check(err) + + d2 := mkdigest("22") + err = PutBytes(c, d2, "22") + check(err) + err = c.Link("H/N/M:T", d2) + check(err) + + var g [2]Digest + g[0], err = c.Resolve("h/n/m:t") + check(err) + g[1], err = c.Resolve("H/N/M:T") + check(err) + + w := [2]Digest{d2, d2} + if g != w { + t.Fatalf("g = %v, want %v", g, w) + } + + var got []string + for l, err := range c.links() { + if err != nil { + t.Fatal(err) + } + got = append(got, l) + } + want := []string{"manifests/h/n/m/t"} + if !slices.Equal(got, want) { + t.Fatalf("got = %v, want %v", got, want) + } + + // relink with different case + err = c.Unlink("h/n/m:t") + check(err) + err = c.Link("h/n/m:T", d1) + check(err) + + got = got[:0] + for l, err := range c.links() { + if err != nil { + t.Fatal(err) + } + got = append(got, l) + } + + // we should have only one link that is same case as the last link + want = []string{"manifests/h/n/m/T"} + if !slices.Equal(got, want) { + t.Fatalf("got = %v, want %v", got, want) + } +} + +func TestManifestFile(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", ""}, + + // valid names + {"h/n/m:t", "/manifests/h/n/m/t"}, + {"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"}, + + {"%/%/%/%", ""}, + + // already a path + {"h/n/m/t", ""}, + + // refs are not names + {"h/n/m:t@sha256-1", ""}, + {"m@sha256-1", ""}, + {"n/m:t@sha256-1", ""}, + } + + c, _ := openTester(t) + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + got, err := c.manifestPath(tt.in) + if err != nil && tt.want != "" { + t.Fatalf("unexpected error: %v", err) + } + if err == nil && tt.want == "" { + t.Fatalf("expected error") + } + dir := filepath.ToSlash(c.dir) + got = filepath.ToSlash(got) + got = strings.TrimPrefix(got, dir) + if got != tt.want { + t.Fatalf("got = %q, want %q", got, tt.want) + } + }) + } +} + +func TestNames(t *testing.T) { + c, _ := openTester(t) + check := testutil.Checker(t) + + check(PutBytes(c, mkdigest("1"), "1")) + check(PutBytes(c, mkdigest("2"), "2")) + + check(c.Link("h/n/m:t", mkdigest("1"))) + check(c.Link("h/n/m:u", mkdigest("2"))) + + var got []string + for l, err := range c.Links() { + if err != nil { + t.Fatal(err) + } + got = append(got, l) + } + want := []string{"h/n/m:t", "h/n/m:u"} + if !slices.Equal(got, want) { + t.Fatalf("got = %v, want %v", got, want) + } +} + +func mkdigest(s string) Digest { + return Digest{sha256.Sum256([]byte(s))} +} + +func checkNotExists(t *testing.T, c *DiskCache, d Digest) { + t.Helper() + _, err := c.Get(d) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want fs.ErrNotExist", err) + } +} + +func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) { + t.Helper() + return func(d Digest, size int64, mod time.Time) { + t.Helper() + t.Run("checkEntry:"+d.String(), func(t *testing.T) { + t.Helper() + + defer func() { + if t.Failed() { + dumpCacheContents(t, c) + } + }() + + e, err := c.Get(d) + if size == 0 && errors.Is(err, fs.ErrNotExist) { + err = nil + } + if err != nil { + t.Fatal(err) + } + if e.Digest != d { + t.Errorf("e.Digest = %v, want %v", e.Digest, d) + } + if e.Size != size { + t.Fatalf("e.Size = %v, want %v", e.Size, size) + } + + testutil.CheckTime(t, e.Time, mod) 
+ info, err := os.Stat(c.GetFile(d)) + if err != nil { + t.Fatal(err) + } + if info.Size() != size { + t.Fatalf("info.Size = %v, want %v", info.Size(), size) + } + testutil.CheckTime(t, info.ModTime(), mod) + }) + } +} + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} + +func TestNameToPath(t *testing.T) { + _, err := nameToPath("h/n/m:t") + if err != nil { + t.Fatal(err) + } +} + +type errOnBangReader struct { + s string + n int +} + +func (e *errOnBangReader) Read(p []byte) (int, error) { + if len(p) < 1 { + return 0, io.ErrShortBuffer + } + if e.n >= len(p) { + return 0, io.EOF + } + if e.s[e.n] == '!' { + return 0, errors.New("bang") + } + p[0] = e.s[e.n] + e.n++ + return 1, nil +} + +func dumpCacheContents(t *testing.T, c *DiskCache) { + t.Helper() + + var b strings.Builder + fsys := os.DirFS(c.dir) + fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + t.Helper() + + if err != nil { + return err + } + info, err := d.Info() + if err != nil { + return err + } + + // Format like ls: + // + // ; ls -la + // drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123 + // drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m + + fmt.Fprintf(&b, " %s % 4d %s %s\n", + info.Mode(), + info.Size(), + info.ModTime().Format("Jan 2 15:04"), + path, + ) + return nil + }) + t.Log() + t.Logf("cache contents:\n%s", b.String()) +} diff --git a/server/internal/cache/blob/casecheck_test.go b/server/internal/cache/blob/casecheck_test.go new file mode 100644 index 000000000..f0842ef91 --- /dev/null +++ b/server/internal/cache/blob/casecheck_test.go @@ -0,0 +1,93 @@ +package blob + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func isCaseSensitive(dir string) bool { + defer func() { + os.Remove(filepath.Join(dir, "_casecheck")) + }() + + exists := func(file string) bool { + _, err := os.Stat(file) + return err == nil + } + + file := filepath.Join(dir, "_casecheck") + FILE := filepath.Join(dir, "_CASECHECK") + if exists(file) || exists(FILE) { + panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir)) + } + + err := os.WriteFile(file, nil, 0o666) + if err != nil { + panic(err) + } + + return !exists(FILE) +} + +func isCI() bool { + return os.Getenv("CI") != "" +} + +const volumeHint = ` + + Unable to locate case-insensitive TMPDIR on darwin. + + To run tests, create the case-insensitive volume /Volumes/data: + + $ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data + + or run with: + + CI=1 go test ./... + +` + +// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory +// can find one, otherwise it skips the test if the CI environment variable is +// set, or GOOS is not darwin. +func useCaseInsensitiveTempDir(t *testing.T) bool { + if isCaseSensitive(os.TempDir()) { + // Use the default temp dir if it is already case-sensitive. + return true + } + if runtime.GOOS == "darwin" { + // If darwin, check for the special case-sensitive volume and + // use it if available. + const volume = "/Volumes/data" + _, err := os.Stat(volume) + if err == nil { + tmpdir := filepath.Join(volume, "tmp") + os.MkdirAll(tmpdir, 0o700) + t.Setenv("TMPDIR", tmpdir) + return true + } + if isCI() { + // Special case darwin in CI; it is not case-sensitive + // by default, and we will be testing other platforms + // that are case-sensitive, so we'll have the test + // being skipped covered there. 
+ t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.") + } + } + + if !isCI() { + // Require devs to always tests with a case-insensitive TMPDIR. + + // TODO(bmizerany): Print platform-specific instructions or + // link to docs on that topic. + lines := strings.Split(volumeHint, "\n") + for _, line := range lines { + t.Log(line) + } + } + return false +} diff --git a/server/internal/cache/blob/digest.go b/server/internal/cache/blob/digest.go new file mode 100644 index 000000000..723ba222c --- /dev/null +++ b/server/internal/cache/blob/digest.go @@ -0,0 +1,95 @@ +package blob + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "slices" + "strings" +) + +var ErrInvalidDigest = errors.New("invalid digest") + +// Digest is a blob identifier that is the SHA-256 hash of a blob's content. +// +// It is comparable and can be used as a map key. +type Digest struct { + sum [32]byte +} + +// ParseDigest parses a digest from a string. If the string is not a valid +// digest, a call to the returned digest's IsValid method will return false. +// +// The input string may be in one of two forms: +// +// - ("sha256-"), where is a 64-character hexadecimal string. +// - ("sha256:"), where is a 64-character hexadecimal string. +// +// The [Digest.String] method will return the canonical form of the +// digest, "sha256:". +func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) { + s := string(v) + i := strings.IndexAny(s, ":-") + var zero Digest + if i < 0 { + return zero, ErrInvalidDigest + } + + prefix, sum := s[:i], s[i+1:] + if prefix != "sha256" || len(sum) != 64 { + return zero, ErrInvalidDigest + } + + var d Digest + _, err := hex.Decode(d.sum[:], []byte(sum)) + if err != nil { + return zero, ErrInvalidDigest + } + return d, nil +} + +func DigestFromBytes[S ~[]byte | ~string](v S) Digest { + return Digest{sha256.Sum256([]byte(v))} +} + +// String returns the string representation of the digest in the conventional +// form "sha256:". +func (d Digest) String() string { + return fmt.Sprintf("sha256:%x", d.sum[:]) +} + +func (d Digest) Short() string { + return fmt.Sprintf("%x", d.sum[:4]) +} + +func (d Digest) Compare(other Digest) int { + return slices.Compare(d.sum[:], other.sum[:]) +} + +// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash +// of some content. +func (d Digest) IsValid() bool { + return d != (Digest{}) +} + +// MarshalText implements the encoding.TextMarshaler interface. It returns an +// error if [Digest.IsValid] returns false. +func (d Digest) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface, and only +// works for a zero digest. If [Digest.IsValid] returns true, it returns an +// error. 
+func (d *Digest) UnmarshalText(text []byte) error { + if *d != (Digest{}) { + return errors.New("digest: illegal UnmarshalText on valid digest") + } + v, err := ParseDigest(string(text)) + if err != nil { + return err + } + *d = v + return nil +} diff --git a/server/internal/cache/blob/digest_test.go b/server/internal/cache/blob/digest_test.go new file mode 100644 index 000000000..c96ad383b --- /dev/null +++ b/server/internal/cache/blob/digest_test.go @@ -0,0 +1,63 @@ +package blob + +import ( + "encoding/json" + "testing" +) + +func TestParseDigest(t *testing.T) { + cases := []struct { + in string + valid bool + }{ + {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true}, + {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true}, + + // too short + {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false}, + {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false}, + + // too long + {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false}, + {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false}, + + // invalid prefix + {"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false}, + {"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false}, + {"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false}, + + // invalid hex + {"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false}, + {"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false}, + } + + for _, tt := range cases { + got, err := ParseDigest(tt.in) + if tt.valid && err != nil { + t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err) + } + want := "sha256:" + tt.in[7:] + if tt.valid && got.String() != want { + t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want) + } + } +} + +func TestDigestMarshalText(t *testing.T) { + const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"` + var d Digest + if err := json.Unmarshal([]byte(s), &d); err != nil { + t.Errorf("json.Unmarshal: %v", err) + } + out, err := json.Marshal(d) + if err != nil { + t.Errorf("json.Marshal: %v", err) + } + want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"` + if string(out) != want { + t.Errorf("json.Marshal: got %s, want %s", out, want) + } + if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil { + t.Errorf("json.Unmarshal: expected error") + } +} diff --git a/server/internal/chunks/chunks.go b/server/internal/chunks/chunks.go new file mode 100644 index 000000000..7eb7a6c17 --- /dev/null +++ b/server/internal/chunks/chunks.go @@ -0,0 +1,78 @@ +package chunks + +import ( + "fmt" + "iter" + "strconv" + "strings" +) + +type Chunk struct { + Start, End int64 +} + +func New(start, end int64) Chunk { + return Chunk{start, end} +} + +// ParseRange parses a string in the form "unit=range" where unit is a string +// and range is a string in the form "start-end". It returns the unit and the +// range as a Chunk. +func ParseRange(s string) (unit string, _ Chunk, _ error) { + unit, r, _ := strings.Cut(s, "=") + if r == "" { + return unit, Chunk{}, nil + } + c, err := Parse(r) + if err != nil { + return "", Chunk{}, err + } + return unit, c, err +} + +// Parse parses a string in the form "start-end" and returns the Chunk. 
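+//
+// For example (a sketch; the values are illustrative):
+//
+//	c, err := chunks.Parse("0-1023")                    // Chunk{Start: 0, End: 1023}
+//	unit, rng, err2 := chunks.ParseRange("bytes=0-1023") // unit == "bytes", rng == c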
+func Parse(s string) (Chunk, error) { + startStr, endStr, _ := strings.Cut(s, "-") + start, err := strconv.ParseInt(startStr, 10, 64) + if err != nil { + return Chunk{}, fmt.Errorf("invalid start: %v", err) + } + end, err := strconv.ParseInt(endStr, 10, 64) + if err != nil { + return Chunk{}, fmt.Errorf("invalid end: %v", err) + } + if start > end { + return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end) + } + return Chunk{start, end}, nil +} + +// Of returns a sequence of contiguous Chunks of size chunkSize that cover +// the range [0, size), in order. +func Of(size, chunkSize int64) iter.Seq[Chunk] { + return func(yield func(Chunk) bool) { + for start := int64(0); start < size; start += chunkSize { + end := min(start+chunkSize-1, size-1) + if !yield(Chunk{start, end}) { + break + } + } + } +} + +// Count returns the number of Chunks of size chunkSize needed to cover the +// range [0, size). +func Count(size, chunkSize int64) int64 { + return (size + chunkSize - 1) / chunkSize +} + +// Size returns end minus start plus one. +func (c Chunk) Size() int64 { + return c.End - c.Start + 1 +} + +// String returns the string representation of the Chunk in the form +// "{start}-{end}". +func (c Chunk) String() string { + return fmt.Sprintf("%d-%d", c.Start, c.End) +} diff --git a/server/internal/chunks/chunks_test.go b/server/internal/chunks/chunks_test.go new file mode 100644 index 000000000..c23e0de8e --- /dev/null +++ b/server/internal/chunks/chunks_test.go @@ -0,0 +1,65 @@ +package chunks + +import ( + "slices" + "testing" +) + +func TestOf(t *testing.T) { + cases := []struct { + total int64 + chunkSize int64 + want []Chunk + }{ + {0, 1, nil}, + {1, 1, []Chunk{{0, 0}}}, + {1, 2, []Chunk{{0, 0}}}, + {2, 1, []Chunk{{0, 0}, {1, 1}}}, + {10, 9, []Chunk{{0, 8}, {9, 9}}}, + } + + for _, tt := range cases { + got := slices.Collect(Of(tt.total, tt.chunkSize)) + if !slices.Equal(got, tt.want) { + t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want) + } + } +} + +func TestSize(t *testing.T) { + cases := []struct { + c Chunk + want int64 + }{ + {Chunk{0, 0}, 1}, + {Chunk{0, 1}, 2}, + {Chunk{3, 4}, 2}, + } + + for _, tt := range cases { + got := tt.c.Size() + if got != tt.want { + t.Errorf("%v: got %d; want %d", tt.c, got, tt.want) + } + } +} + +func TestCount(t *testing.T) { + cases := []struct { + total int64 + chunkSize int64 + want int64 + }{ + {0, 1, 0}, + {1, 1, 1}, + {1, 2, 1}, + {2, 1, 2}, + {10, 9, 2}, + } + for _, tt := range cases { + got := Count(tt.total, tt.chunkSize) + if got != tt.want { + t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want) + } + } +} diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go new file mode 100644 index 000000000..136122721 --- /dev/null +++ b/server/internal/client/ollama/registry.go @@ -0,0 +1,802 @@ +// Package ollama provides a client for interacting with an Ollama registry +// which pushes and pulls model manifests and layers as defined by the +// [ollama.com/manifest]. 
+package ollama + +import ( + "bufio" + "bytes" + "cmp" + "context" + "crypto" + "crypto/ed25519" + "crypto/sha256" + "crypto/tls" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" + + "github.com/ollama/ollama/server/internal/cache/blob" + "github.com/ollama/ollama/server/internal/chunks" + "github.com/ollama/ollama/server/internal/internal/backoff" + "github.com/ollama/ollama/server/internal/internal/names" + "github.com/ollama/ollama/server/internal/internal/syncs" + + _ "embed" +) + +// Errors +var ( + // ErrManifestNotFound is returned when a manifest is not found in the + // cache or registry. + ErrManifestNotFound = errors.New("manifest not found") + + // ErrManifestInvalid is returned when a manifest found in a local or + // remote cache is invalid. + ErrManifestInvalid = errors.New("invalid manifest") + + // ErrMissingModel is returned when the model part of a name is missing + // or invalid. + ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}") + + // ErrCached is passed to [Trace.PushUpdate] when a layer already + // exists. It is a non-fatal error and is never returned by [Registry.Push]. + ErrCached = errors.New("cached") +) + +// Defaults +const ( + // DefaultChunkingThreshold is the threshold at which a layer should be + // split up into chunks when downloading. + DefaultChunkingThreshold = 128 << 20 + + // DefaultMaxChunkSize is the default maximum size of a chunk to + // download. It is configured based on benchmarks and aims to strike a + // balance between download speed and memory usage. + DefaultMaxChunkSize = 8 << 20 +) + +// DefaultCache returns a new disk cache for storing models. If the +// OLLAMA_MODELS environment variable is set, it uses that directory; +// otherwise, it uses $HOME/.ollama/models. +func DefaultCache() (*blob.DiskCache, error) { + dir := os.Getenv("OLLAMA_MODELS") + if dir == "" { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + dir = filepath.Join(home, ".ollama", "models") + } + return blob.Open(dir) +} + +// Error is the standard error returned by Ollama APIs. +type Error struct { + Status int `json:"-"` + Code string `json:"code"` + Message string `json:"message"` +} + +func (e *Error) Error() string { + return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (e *Error) UnmarshalJSON(b []byte) error { + type E Error + var v struct{ Errors []E } + if err := json.Unmarshal(b, &v); err != nil { + return err + } + if len(v.Errors) == 0 { + return fmt.Errorf("no messages in error response: %s", string(b)) + } + *e = Error(v.Errors[0]) // our registry only returns one error. + return nil +} + +// TODO(bmizerany): make configurable on [Registry] +var defaultName = func() names.Name { + n := names.Parse("ollama.com/library/_:latest") + if !n.IsFullyQualified() { + panic("default name is not fully qualified") + } + return n +}() + +// Registry is a client for performing push and pull operations against an +// Ollama registry. +type Registry struct { + // UserAgent is the User-Agent header to send with requests to the + // registry. If empty, the User-Agent is determined by HTTPClient. + UserAgent string + + // Key is the key used to authenticate with the registry. 
+ // + // Currently, only Ed25519 keys are supported. + Key crypto.PrivateKey + + // HTTPClient is the HTTP client used to make requests to the registry. + // + // If nil, [http.DefaultClient] is used. + // + // As a quick note: if a Registry function makes a call to a URL + // with the "https+insecure" scheme, the client will be cloned and the + // transport will be set to skip TLS verification, unless the client's + // Transport does not have a Clone method with the same signature as + // [http.Transport.Clone], in which case the call will fail. + HTTPClient *http.Client + + // MaxStreams is the maximum number of concurrent streams to use when + // pushing or pulling models. If zero, the number of streams is + // determined by [runtime.GOMAXPROCS]. + // + // Clients that want "unlimited" streams should set this to a large + // number. + MaxStreams int + + // ChunkingThreshold is the maximum size of a layer to download in a single + // request. If zero, [DefaultChunkingThreshold] is used. + ChunkingThreshold int64 + + // MaxChunkSize is the maximum size of a chunk to download. If zero, + // the default is [DefaultMaxChunkSize]. + // + // It is only used when a layer is larger than [ChunkingThreshold]. + MaxChunkSize int64 +} + +// RegistryFromEnv returns a new Registry configured from the environment. The +// key is read from $HOME/.ollama/id_ed25519, and MaxStreams is set to the +// value of OLLAMA_REGISTRY_MAXSTREAMS. +// +// It returns an error if any configuration in the environment is invalid. +func RegistryFromEnv() (*Registry, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519")) + if err != nil { + return nil, err + } + + var rc Registry + rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) + if err != nil { + return nil, err + } + maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS") + if maxStreams != "" { + var err error + rc.MaxStreams, err = strconv.Atoi(maxStreams) + if err != nil { + return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err) + } + } + return &rc, nil +} + +type PushParams struct { + // From is an optional source name for the model. If empty, the + // source name is the same as the destination name. + From string +} + +// parseName parses name using [names.ParseExtended] and then merges the name with the +// default name, and checks that the name is fully qualified. If a digest is +// present, it is parsed and returned with the other fields set to their zero values. +// +// It returns an error if the name is not fully qualified, or if the digest, if +// any, is invalid. +// +// The scheme is returned as provided by [names.ParseExtended]. +func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) { + scheme, n, ds := names.ParseExtended(s) + n = names.Merge(n, defaultName) + if ds != "" { + // Digest is present. Validate it. + d, err = blob.ParseDigest(ds) + if err != nil { + return "", names.Name{}, blob.Digest{}, err + } + } + + // The name check is deferred until after the digest check because we + // say that digests take precedence over names, and so should their + // errors when being parsed.
+ if !n.IsFullyQualified() { + return "", names.Name{}, blob.Digest{}, ErrNameInvalid + } + + scheme = cmp.Or(scheme, "https") + return scheme, n, d, nil +} + +func (r *Registry) maxStreams() int { + n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) + + // Large downloads require a writter stream, so ensure we have at least + // two streams to avoid a deadlock. + return max(n, 2) +} + +func (r *Registry) maxChunkingThreshold() int64 { + return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold) +} + +// chunkSizeFor returns the chunk size for a layer of the given size. If the +// size is less than or equal to the max chunking threshold, the size is +// returned; otherwise, the max chunk size is returned. +func (r *Registry) maxChunkSize() int64 { + return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize) +} + +// Push pushes the model with the name in the cache to the remote registry. +func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { + if p == nil { + p = &PushParams{} + } + + m, err := ResolveLocal(c, cmp.Or(p.From, name)) + if err != nil { + return err + } + + // Before much else happens, check layers at not null, the blobs exist, + // and the sizes match. This prevents long uploads followed by + // disappointment. + for _, l := range m.Layers { + if l == nil { + return fmt.Errorf("%w: null layer", ErrManifestInvalid) + } + info, err := c.Get(l.Digest) + if err != nil { + return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err) + } + if info.Size != l.Size { + return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size) + } + } + + t := traceFromContext(ctx) + + scheme, n, _, err := parseName(name) + if err != nil { + // This should never happen since ResolveLocal should have + // already validated the name. + panic(err) + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var g errgroup.Group + g.SetLimit(r.maxStreams()) + for _, l := range m.Layers { + var progress atomic.Int64 + g.Go(func() (err error) { + defer func() { t.update(l, progress.Load(), err) }() + + t.update(l, 0, nil) + + startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s", + scheme, + n.Host(), + n.Namespace(), + n.Model(), + l.Digest, + ) + res, err := r.doOK(ctx, "POST", startURL, nil) + if err != nil { + return err + } + res.Body.Close() + + f, err := os.Open(c.GetFile(l.Digest)) + if err != nil { + return err + } + defer f.Close() + + uploadURL := res.Header.Get("Location") + if uploadURL == "" { + t.update(l, l.Size, ErrCached) + return nil + } + + req, err := r.newRequest(ctx, "PUT", uploadURL, f) + if err != nil { + return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err) + } + req.ContentLength = l.Size + + res, err = doOK(r.client(), req) + if err == nil { + res.Body.Close() + } + return err + }) + } + + if err := g.Wait(); err != nil { + return err + } + + // Commit + path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", + scheme, + n.Host(), + n.Namespace(), + n.Model(), + n.Tag(), + ) + res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data)) + if err == nil { + res.Body.Close() + } + // TODO(bmizerany): add a "commit" trace event + return err +} + +func canRetry(err error) bool { + var re *Error + if !errors.As(err, &re) { + return false + } + return re.Status >= 500 +} + +// Pull pulls the model with the given name from the remote registry into the +// cache. 
+// +// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in +// chunks of the specified size, and then reassembled and verified. This is +// typically slower than splitting the model up across layers, and is mostly +// utilized for layers of type equal to "application/vnd.ollama.image". +func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { + scheme, n, _, err := parseName(name) + if err != nil { + return err + } + + m, err := r.Resolve(ctx, name) + if err != nil { + return err + } + if len(m.Layers) == 0 { + return fmt.Errorf("%w: no layers", ErrManifestInvalid) + } + + exists := func(l *Layer) bool { + info, err := c.Get(l.Digest) + return err == nil && info.Size == l.Size + } + + t := traceFromContext(ctx) + + var g errgroup.Group + g.SetLimit(r.maxStreams()) + + for _, l := range m.Layers { + if exists(l) { + t.update(l, l.Size, ErrCached) + continue + } + + blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest) + req, err := r.newRequest(ctx, "GET", blobURL, nil) + if err != nil { + t.update(l, 0, err) + continue + } + + t.update(l, 0, nil) + + if l.Size <= r.maxChunkingThreshold() { + g.Go(func() error { + res, err := doOK(r.client(), req) + if err != nil { + return err + } + defer res.Body.Close() + err = c.Put(l.Digest, res.Body, l.Size) + if err == nil { + t.update(l, l.Size, nil) + } + return err + }) + } else { + q := syncs.NewRelayReader() + + g.Go(func() (err error) { + defer func() { q.CloseWithError(err) }() + return c.Put(l.Digest, q, l.Size) + }) + + var progress atomic.Int64 + + // We want to avoid extra round trips per chunk due to + // redirects from the registry to the blob store, so + // fire an initial request to get the final URL and + // then use that URL for the chunk requests. + req.Header.Set("Range", "bytes=0-0") + res, err := doOK(r.client(), req) + if err != nil { + return err + } + res.Body.Close() + req = res.Request.WithContext(req.Context()) + + streamNo := 0 + tws := make([]*bufio.Writer, r.maxStreams()-1) + for chunk := range chunks.Of(l.Size, r.maxChunkSize()) { + ticket := q.Take() + bufIdx := streamNo % len(tws) + streamNo++ + g.Go(func() (err error) { + defer func() { + if err != nil { + q.CloseWithError(err) + } + ticket.Close() + t.update(l, progress.Load(), err) + }() + + for _, err := range backoff.Loop(ctx, 3*time.Second) { + if err != nil { + return err + } + + err := func() error { + req := req.Clone(req.Context()) + req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) + res, err := doOK(r.client(), req) + if err != nil { + return err + } + defer res.Body.Close() + + tw := tws[bufIdx] + if tw == nil { + tw = bufio.NewWriterSize(nil, int(r.maxChunkSize())) + tws[bufIdx] = tw + } + tw.Reset(ticket) + defer tw.Reset(nil) // release ticket + + _, err = io.CopyN(tw, res.Body, chunk.Size()) + if err != nil { + return maybeUnexpectedEOF(err) + } + if err := tw.Flush(); err != nil { + return err + } + + total := progress.Add(chunk.Size()) + if total >= l.Size { + q.Close() + } + return nil + }() + if !canRetry(err) { + return err + } + } + return nil + }) + } + } + } + + if err := g.Wait(); err != nil { + return err + } + + // store the manifest blob + md := blob.DigestFromBytes(m.Data) + if err := blob.PutBytes(c, md, m.Data); err != nil { + return err + } + + // commit the manifest with a link + return c.Link(m.Name, md) +} + +// Manifest represents a [ollama.com/manifest]. 
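+//
+// As a rough sketch, the wire form produced by [Manifest.MarshalJSON] looks
+// like the following (the digest, media type, and size values are illustrative):
+//
+//	{
+//	  "layers": [
+//	    {"digest": "sha256:<hex>", "mediaType": "application/vnd.ollama.image", "size": 123}
+//	  ],
+//	  "config": {"digest": "sha256:000...000", "mediaType": "", "size": 0}
+//	}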
+type Manifest struct { + Name string `json:"-"` // the canonical name of the model + Data []byte `json:"-"` // the raw data of the manifest + Layers []*Layer `json:"layers"` +} + +var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000") + +// Layer returns the layer with the given +// digest, or nil if not found. +func (m *Manifest) Layer(d blob.Digest) *Layer { + for _, l := range m.Layers { + if l.Digest == d { + return l + } + } + return nil +} + +// MarshalJSON implements json.Marshaler. +// +// NOTE: It adds an empty config object to the manifest, which is required by +// the registry, but not used by the client. In the future, the config object +// will not be required by the registry and this should be removed. +func (m Manifest) MarshalJSON() ([]byte, error) { + type M Manifest + v := struct { + M + + // This is ignored, mostly, by the registry. But, if not + // present, it will cause an error to be returned during the + // last phase of the commit which expects it, but does nothing + // with it. This will be fixed in a future release of + // ollama.com. + Config *Layer `json:"config"` + }{ + M: M(m), + Config: &Layer{Digest: emptyDigest}, + } + return json.Marshal(v) +} + +// unmarshalManifest unmarshals the data into a manifest, and sets the name +// field to the string representation of the name. +// +// It panics if the name is not fully qualified. Callers should ensure the name +// is fully qualified before calling this function. +func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) { + if !n.IsFullyQualified() { + panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String())) + } + var m Manifest + if err := json.Unmarshal(data, &m); err != nil { + return nil, err + } + m.Name = n.String() + m.Data = data + return &m, nil +} + +// Layer is a layer in a model. +type Layer struct { + Digest blob.Digest `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` +} + +// ResolveLocal resolves a name to a Manifest in the local cache. The name is +// parsed using [names.ParseExtended] but the scheme is ignored. +func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { + _, n, d, err := parseName(name) + if err != nil { + return nil, err + } + if !d.IsValid() { + d, err = c.Resolve(n.String()) + if err != nil { + return nil, err + } + } + data, err := os.ReadFile(c.GetFile(d)) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name) + } + return nil, err + } + m, err := unmarshalManifest(n, data) + if err != nil { + return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err)) + } + return m, nil +} + +// Resolve resolves a name to a Manifest in the remote registry.
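+//
+// A hedged end-to-end sketch (the model name is illustrative, and ctx is a
+// context.Context assumed to come from the caller):
+//
+//	c, err := ollama.DefaultCache()
+//	if err != nil {
+//		// handle error
+//	}
+//	var r ollama.Registry
+//	m, err := r.Resolve(ctx, "ollama.com/library/mymodel:latest")
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = m.Layers // inspect layers, or pull them into the cache:
+//	err = r.Pull(ctx, c, "ollama.com/library/mymodel:latest")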
+func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { + scheme, n, d, err := parseName(name) + if err != nil { + return nil, err + } + + manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag()) + if d.IsValid() { + manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d) + } + + res, err := r.doOK(ctx, "GET", manifestURL, nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + // TODO(bmizerany): return digest here + m, err := unmarshalManifest(n, data) + if err != nil { + return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err)) + } + return m, nil +} + +func (r *Registry) client() *http.Client { + if r.HTTPClient != nil { + return r.HTTPClient + } + return http.DefaultClient +} + +// newRequest constructs a new request, ready to use, with the given method, +// url, and body, presigned with client Key and UserAgent. +func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + if r.UserAgent != "" { + req.Header.Set("User-Agent", r.UserAgent) + } + if r.Key != nil { + token, err := makeAuthToken(r.Key) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + } + return req, nil +} + +// doOK makes a request with the given client and request, and returns the +// response if the status code is 200. If the status code is not 200, an Error +// is parsed from the response body and returned. If any other error occurs, it +// is returned. +func doOK(c *http.Client, r *http.Request) (*http.Response, error) { + if r.URL.Scheme == "https+insecure" { + // TODO(bmizerany): clone client.Transport, set + // InsecureSkipVerify, etc. + + type cloner interface { + Clone() *http.Transport + } + + // Attempt to configure the transport to skip TLS verification + // if we can clone it, otherwise fall through and let the http + // client complain and the scheme being invalid. + x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner) + if ok { + tr := x.Clone() + tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{}) + tr.TLSClientConfig.InsecureSkipVerify = true + + cc := *c // shallow copy + cc.Transport = tr + c = &cc + + r = r.Clone(r.Context()) + r.URL.Scheme = "https" + + // fall through + } + } + + res, err := c.Do(r) + if err != nil { + return nil, err + } + if res.StatusCode/100 != 2 { + out, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + var re Error + if err := json.Unmarshal(out, &re); err != nil { + // Use the raw body if we can't parse it as an error object. + re.Message = string(out) + } + re.Status = res.StatusCode + return nil, &re + } + return res, nil +} + +// doOK is a convenience method for making a request with newRequest and +// passing it to doOK with r.client(). +func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { + req, err := r.newRequest(ctx, method, path, body) + if err != nil { + return nil, err + } + return doOK(r.client(), req) +} + +// makeAuthToken creates an Ollama auth token for the given private key. +// +// NOTE: This format is OLD, overly complex, and should be replaced. 
We're +// inheriting it from the original Ollama client and ollama.com +// implementations, so we need to support it for now. +func makeAuthToken(key crypto.PrivateKey) (string, error) { + privKey, _ := key.(*ed25519.PrivateKey) + if privKey == nil { + return "", fmt.Errorf("unsupported private key type: %T", key) + } + + url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix()) + // Part 1: the checkData (e.g. the URL with a timestamp) + + // Part 2: the public key + pubKeyShort, err := func() ([]byte, error) { + sshPubKey, err := ssh.NewPublicKey(privKey.Public()) + if err != nil { + return nil, err + } + pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey)) + if len(pubKeyParts) < 2 { + return nil, fmt.Errorf("malformed public key: %q", pubKeyParts) + } + pubKeyShort := pubKeyParts[1] + return pubKeyShort, nil + }() + if err != nil { + return "", err + } + + // Part 3: the signature + sig := ed25519.Sign(*privKey, []byte(checkData(url))) + + // Assemble the token: :: + var b strings.Builder + io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url))) + b.WriteByte(':') + b.Write(pubKeyShort) + b.WriteByte(':') + io.WriteString(&b, base64.StdEncoding.EncodeToString(sig)) + + return b.String(), nil +} + +// The original spec for Ollama tokens was to use the SHA256 of the zero +// string as part of the signature. I'm not sure why that was, but we still +// need it to verify the signature. +var zeroSum = func() string { + sha256sum := sha256.Sum256(nil) + x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))) + return x +}() + +// checkData takes a URL and creates the original string format of the +// data signature that is used by the ollama client to sign requests +func checkData(url string) string { + return fmt.Sprintf("GET,%s,%s", url, zeroSum) +} + +func maybeUnexpectedEOF(err error) error { + if errors.Is(err, io.EOF) { + return io.ErrUnexpectedEOF + } + return err +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go new file mode 100644 index 000000000..d8f2a4077 --- /dev/null +++ b/server/internal/client/ollama/registry_test.go @@ -0,0 +1,656 @@ +package ollama + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "math/rand/v2" + "net/http" + "net/http/httptest" + "os" + "path" + "reflect" + "slices" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/server/internal/cache/blob" + "github.com/ollama/ollama/server/internal/chunks" + "github.com/ollama/ollama/server/internal/internal/testutil" +) + +func TestManifestMarshalJSON(t *testing.T) { + // All manifests should contain an "empty" config object. 
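+ // For a zero Manifest the output is expected to look roughly like the
+ // following (a sketch; the config digest is the all-zero emptyDigest
+ // defined in this package):
+ //
+ //	{"layers":null,"config":{"digest":"sha256:00...00","mediaType":"","size":0}}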
+ var m Manifest + data, err := json.Marshal(m) + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) { + t.Error("expected manifest to contain empty config") + t.Fatalf("got:\n%s", string(data)) + } +} + +func link(c *blob.DiskCache, name string, manifest string) { + _, n, _, err := parseName(name) + if err != nil { + panic(err) + } + d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest))) + if err != nil { + panic(err) + } + if err := c.Link(n.String(), d); err != nil { + panic(err) + } +} + +var errRoundTrip = errors.New("forced roundtrip error") + +type recordRoundTripper http.HandlerFunc + +func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + w := httptest.NewRecorder() + rr(w, req) + if w.Code == 499 { + return nil, errRoundTrip + } + resp := w.Result() + // For some reason, Response.Request is not set by httptest.NewRecorder, so we + // set it manually. + resp.Request = req + return w.Result(), nil +} + +// newClient constructs a cache with predefined manifests for testing. The manifests are: +// +// empty: no data +// zero: no layers +// single: one layer with the contents "exists" +// multiple: two layers with the contents "exists" and "here" +// notfound: a layer that does not exist in the cache +// null: one null layer (e.g. [null]) +// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size) +// invalid: a layer with invalid JSON data +// +// Tests that want to ensure the client does not communicate with the upstream +// registry should pass a nil handler, which will cause a panic if +// communication is attempted. +// +// To simulate a network error, pass a handler that returns a 499 status code. 
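+//
+// A typical call mirrors the tests below (sketch):
+//
+//	rc, c := newClient(t, okHandler)
+//	if err := rc.Push(t.Context(), c, "single", nil); err != nil {
+//		t.Fatal(err)
+//	}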
+func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { + t.Helper() + c, err := blob.Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + + mklayer := func(data string) *Layer { + return &Layer{ + Digest: importBytes(t, c, data), + Size: int64(len(data)), + } + } + + commit := func(name string, layers ...*Layer) { + t.Helper() + data, err := json.Marshal(&Manifest{Layers: layers}) + if err != nil { + t.Fatal(err) + } + link(c, name, string(data)) + } + + link(c, "empty", "") + commit("zero") + commit("single", mklayer("exists")) + commit("multiple", mklayer("exists"), mklayer("present")) + commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))}) + commit("null", nil) + commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499}) + link(c, "invalid", "!!!!!") + + rc := &Registry{ + HTTPClient: &http.Client{ + Transport: recordRoundTripper(h), + }, + } + return rc, c +} + +func okHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func checkErrCode(t *testing.T, err error, status int, code string) { + t.Helper() + var e *Error + if !errors.As(err, &e) || e.Status != status || e.Code != code { + t.Errorf("err = %v; want %v %v", err, status, code) + } +} + +func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest { + d, err := c.Import(strings.NewReader(data), int64(len(data))) + if err != nil { + t.Fatal(err) + } + return d +} + +func TestRegistryPushInvalidNames(t *testing.T) { + rc, c := newClient(t, nil) + + cases := []struct { + name string + err error + }{ + {"", ErrNameInvalid}, + {"@", ErrNameInvalid}, + {"@x", blob.ErrInvalidDigest}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + // Create a new registry and push a new image. 
+ err := rc.Push(t.Context(), c, tt.name, nil) + if !errors.Is(err, tt.err) { + t.Errorf("err = %v; want %v", err, tt.err) + } + }) + } +} + +func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { + t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }} + return WithTrace(ctx, t), t +} + +func TestPushZero(t *testing.T) { + rc, c := newClient(t, okHandler) + err := rc.Push(t.Context(), c, "empty", nil) + if !errors.Is(err, ErrManifestInvalid) { + t.Errorf("err = %v; want %v", err, ErrManifestInvalid) + } +} + +func TestPushSingle(t *testing.T) { + rc, c := newClient(t, okHandler) + err := rc.Push(t.Context(), c, "single", nil) + testutil.Check(t, err) +} + +func TestPushMultiple(t *testing.T) { + rc, c := newClient(t, okHandler) + err := rc.Push(t.Context(), c, "multiple", nil) + testutil.Check(t, err) +} + +func TestPushNotFound(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + t.Errorf("unexpected request: %v", r) + }) + err := rc.Push(t.Context(), c, "notfound", nil) + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("err = %v; want %v", err, fs.ErrNotExist) + } +} + +func TestPushNullLayer(t *testing.T) { + rc, c := newClient(t, nil) + err := rc.Push(t.Context(), c, "null", nil) + if err == nil || !strings.Contains(err.Error(), "invalid manifest") { + t.Errorf("err = %v; want invalid manifest", err) + } +} + +func TestPushSizeMismatch(t *testing.T) { + rc, c := newClient(t, nil) + ctx, _ := withTraceUnexpected(t.Context()) + got := rc.Push(ctx, c, "sizemismatch", nil) + if got == nil || !strings.Contains(got.Error(), "size mismatch") { + t.Errorf("err = %v; want size mismatch", got) + } +} + +func TestPushInvalid(t *testing.T) { + rc, c := newClient(t, nil) + err := rc.Push(t.Context(), c, "invalid", nil) + if err == nil || !strings.Contains(err.Error(), "invalid manifest") { + t.Errorf("err = %v; want invalid manifest", err) + } +} + +func TestPushExistsAtRemote(t *testing.T) { + var pushed bool + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/uploads/") { + if !pushed { + // First push. Return an uploadURL. 
+ pushed = true + w.Header().Set("Location", "http://blob.store/blobs/123") + return + } + w.WriteHeader(http.StatusAccepted) + return + } + + io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + }) + + rc.MaxStreams = 1 // prevent concurrent uploads + + var errs []error + ctx := WithTrace(t.Context(), &Trace{ + Update: func(_ *Layer, n int64, err error) { + // uploading one at a time so no need to lock + errs = append(errs, err) + }, + }) + + check := testutil.Checker(t) + + err := rc.Push(ctx, c, "single", nil) + check(err) + + if !errors.Is(errors.Join(errs...), nil) { + t.Errorf("errs = %v; want %v", errs, []error{ErrCached}) + } + + err = rc.Push(ctx, c, "single", nil) + check(err) +} + +func TestPushRemoteError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + w.WriteHeader(500) + io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`) + return + } + }) + got := rc.Push(t.Context(), c, "single", nil) + checkErrCode(t, got, 500, "blob_error") +} + +func TestPushLocationError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", ":///x") + w.WriteHeader(http.StatusAccepted) + }) + got := rc.Push(t.Context(), c, "single", nil) + wantContains := "invalid upload URL" + if got == nil || !strings.Contains(got.Error(), wantContains) { + t.Errorf("err = %v; want to contain %v", got, wantContains) + } +} + +func TestPushUploadRoundtripError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Host == "blob.store" { + w.WriteHeader(499) // force RoundTrip error on upload + return + } + w.Header().Set("Location", "http://blob.store/blobs/123") + }) + got := rc.Push(t.Context(), c, "single", nil) + if !errors.Is(got, errRoundTrip) { + t.Errorf("got = %v; want %v", got, errRoundTrip) + } +} + +func TestPushUploadFileOpenError(t *testing.T) { + rc, c := newClient(t, okHandler) + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, _ int64, err error) { + // Remove the file just before it is opened for upload, + // but after the initial Stat that happens before the + // upload starts + os.Remove(c.GetFile(l.Digest)) + }, + }) + got := rc.Push(ctx, c, "single", nil) + if !errors.Is(got, fs.ErrNotExist) { + t.Errorf("got = %v; want fs.ErrNotExist", got) + } +} + +func TestPushCommitRoundtripError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + panic("unexpected") + } + w.WriteHeader(499) // force RoundTrip error + }) + err := rc.Push(t.Context(), c, "zero", nil) + if !errors.Is(err, errRoundTrip) { + t.Errorf("err = %v; want %v", err, errRoundTrip) + } +} + +func checkNotExist(t *testing.T, err error) { + t.Helper() + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v; want fs.ErrNotExist", err) + } +} + +func TestRegistryPullInvalidName(t *testing.T) { + rc, c := newClient(t, nil) + err := rc.Pull(t.Context(), c, "://") + if !errors.Is(err, ErrNameInvalid) { + t.Errorf("err = %v; want %v", err, ErrNameInvalid) + } +} + +func TestRegistryPullInvalidManifest(t *testing.T) { + cases := []string{ + "", + "null", + "!!!", + `{"layers":[]}`, + } + + for _, resp := range cases { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, resp) + }) + err := rc.Pull(t.Context(), c, "x") + if !errors.Is(err, ErrManifestInvalid) { + t.Errorf("err = %v; want invalid manifest", err) + 
} + } +} + +func TestRegistryPullNotCached(t *testing.T) { + check := testutil.Checker(t) + + var c *blob.DiskCache + var rc *Registry + + d := blob.DigestFromBytes("some data") + rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + io.WriteString(w, "some data") + return + } + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d) + }) + + // Confirm that the layer does not exist locally + _, err := ResolveLocal(c, "model") + checkNotExist(t, err) + + _, err = c.Get(d) + checkNotExist(t, err) + + err = rc.Pull(t.Context(), c, "model") + check(err) + + mw, err := rc.Resolve(t.Context(), "model") + check(err) + mg, err := ResolveLocal(c, "model") + check(err) + if !reflect.DeepEqual(mw, mg) { + t.Errorf("mw = %v; mg = %v", mw, mg) + } + + // Confirm successful download + info, err := c.Get(d) + check(err) + if info.Digest != d { + t.Errorf("info.Digest = %v; want %v", info.Digest, d) + } + if info.Size != 9 { + t.Errorf("info.Size = %v; want %v", info.Size, 9) + } + + data, err := os.ReadFile(c.GetFile(d)) + check(err) + if string(data) != "some data" { + t.Errorf("data = %q; want %q", data, "exists") + } +} + +func TestRegistryPullCached(t *testing.T) { + cached := blob.DigestFromBytes("exists") + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + w.WriteHeader(499) // should not be called + return + } + if strings.Contains(r.URL.Path, "/manifests/") { + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached) + } + }) + + var errs []error + var reads []int64 + ctx := WithTrace(t.Context(), &Trace{ + Update: func(d *Layer, n int64, err error) { + t.Logf("update %v %d %v", d, n, err) + reads = append(reads, n) + errs = append(errs, err) + }, + }) + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + err := rc.Pull(ctx, c, "single") + testutil.Check(t, err) + + want := []int64{6} + if !errors.Is(errors.Join(errs...), ErrCached) { + t.Errorf("errs = %v; want %v", errs, ErrCached) + } + if !slices.Equal(reads, want) { + t.Errorf("pairs = %v; want %v", reads, want) + } +} + +func TestRegistryPullManifestNotFound(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + err := rc.Pull(t.Context(), c, "notfound") + checkErrCode(t, err, 404, "") +} + +func TestRegistryPullResolveRemoteError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) + }) + err := rc.Pull(t.Context(), c, "single") + checkErrCode(t, err, 500, "an_error") +} + +func TestRegistryPullResolveRoundtripError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/manifests/") { + w.WriteHeader(499) // force RoundTrip error + return + } + }) + err := rc.Pull(t.Context(), c, "single") + if !errors.Is(err, errRoundTrip) { + t.Errorf("err = %v; want %v", err, errRoundTrip) + } +} + +// TestRegistryPullMixedCachedNotCached tests that cached layers do not +// interfere with pulling layers that are not cached +func TestRegistryPullMixedCachedNotCached(t *testing.T) { + x := blob.DigestFromBytes("xxxxxx") + e := blob.DigestFromBytes("exists") + y := blob.DigestFromBytes("yyyyyy") + + for i := range 10 { + t.Logf("iteration %d", i) + + digests := []blob.Digest{x, e, y} + + rand.Shuffle(len(digests), 
func(i, j int) { + digests[i], digests[j] = digests[j], digests[i] + }) + + manifest := fmt.Sprintf(`{ + "layers": [ + {"digest":"%s","size":6}, + {"digest":"%s","size":6}, + {"digest":"%s","size":6} + ] + }`, digests[0], digests[1], digests[2]) + + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + switch path.Base(r.URL.Path) { + case "latest": + io.WriteString(w, manifest) + case x.String(): + io.WriteString(w, "xxxxxx") + case e.String(): + io.WriteString(w, "exists") + case y.String(): + io.WriteString(w, "yyyyyy") + default: + panic(fmt.Sprintf("unexpected request: %v", r)) + } + }) + + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Logf("update %v %d %v", l, n, err) + }, + }) + + // Check that we pull all layers that we can. + + err := rc.Pull(ctx, c, "mixed") + if err != nil { + t.Fatal(err) + } + + for _, d := range digests { + info, err := c.Get(d) + if err != nil { + t.Fatalf("Get(%v): %v", d, err) + } + if info.Size != 6 { + t.Errorf("info.Size = %v; want %v", info.Size, 6) + } + } + } +} + +func TestRegistryPullChunking(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) + if r.URL.Host != "blob.store" { + // The production registry redirects to the blob store. + http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound) + return + } + if strings.Contains(r.URL.Path, "/blobs/") { + rng := r.Header.Get("Range") + if rng == "" { + http.Error(w, "missing range", http.StatusBadRequest) + return + } + _, c, err := chunks.ParseRange(r.Header.Get("Range")) + if err != nil { + panic(err) + } + io.WriteString(w, "remote"[c.Start:c.End+1]) + return + } + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote")) + }) + + // Force chunking by setting the threshold to less than the size of the + // layer. 
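+ // With a 6-byte layer and a 3-byte chunk size, the pull should issue the
+ // initial bytes=0-0 probe plus two ranged GETs (bytes=0-2 and bytes=3-5),
+ // and the trace below should observe progress 0, 3, and 6.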
+ rc.ChunkingThreshold = 3 + rc.MaxChunkSize = 3 + + var reads []int64 + ctx := WithTrace(t.Context(), &Trace{ + Update: func(d *Layer, n int64, err error) { + if err != nil { + t.Errorf("update %v %d %v", d, n, err) + } + reads = append(reads, n) + }, + }) + + err := rc.Pull(ctx, c, "remote") + testutil.Check(t, err) + + want := []int64{0, 3, 6} + if !slices.Equal(reads, want) { + t.Errorf("reads = %v; want %v", reads, want) + } +} + +func TestRegistryResolveByDigest(t *testing.T) { + check := testutil.Checker(t) + + exists := blob.DigestFromBytes("exists") + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() { + w.WriteHeader(499) // should not hit manifest endpoint + } + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists) + }) + + _, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String()) + check(err) +} + +func TestInsecureSkipVerify(t *testing.T) { + exists := blob.DigestFromBytes("exists") + + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists) + })) + defer s.Close() + + const name = "ollama.com/library/insecure" + + var rc Registry + url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) + _, err := rc.Resolve(t.Context(), url) + if err == nil || !strings.Contains(err.Error(), "failed to verify") { + t.Errorf("err = %v; want cert verifiction failure", err) + } + + url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name) + _, err = rc.Resolve(t.Context(), url) + testutil.Check(t, err) +} + +func TestCanRetry(t *testing.T) { + cases := []struct { + err error + want bool + }{ + {nil, false}, + {errors.New("x"), false}, + {ErrCached, false}, + {ErrManifestInvalid, false}, + {ErrNameInvalid, false}, + {&Error{Status: 100}, false}, + {&Error{Status: 500}, true}, + } + for _, tt := range cases { + if got := canRetry(tt.err); got != tt.want { + t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want) + } + } +} diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go new file mode 100644 index 000000000..8e53040ad --- /dev/null +++ b/server/internal/client/ollama/trace.go @@ -0,0 +1,48 @@ +package ollama + +import ( + "context" +) + +// Trace is a set of functions that are called to report progress during blob +// downloads and uploads. +type Trace struct { + // Update is called during [Registry.Push] and [Registry.Pull] to + // report the progress of blob uploads and downloads. + // + // It is called once at the beginning of the download with a zero n and + // then once per read operation with the number of bytes read so far, + // and an error if any. + // + // A function assigned must be safe for concurrent use. The function is + // called synchronously and so should not block or take long to run. + Update func(_ *Layer, n int64, _ error) +} + +func (t *Trace) update(l *Layer, n int64, err error) { + if t.Update != nil { + t.Update(l, n, err) + } +} + +type traceKey struct{} + +// WithTrace returns a context derived from ctx that uses t to report trace +// events. +func WithTrace(ctx context.Context, t *Trace) context.Context { + return context.WithValue(ctx, traceKey{}, t) +} + +var emptyTrace = &Trace{} + +// traceFromContext returns the Trace associated with ctx, or an empty Trace if +// none is found. +// +// It never returns nil. 
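+//
+// The typical pattern inside this package is (sketch):
+//
+//	t := traceFromContext(ctx)
+//	t.update(l, n, err) // safe even when no Trace was attached to ctx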
+func traceFromContext(ctx context.Context) *Trace { + t, _ := ctx.Value(traceKey{}).(*Trace) + if t == nil { + return emptyTrace + } + return t +} diff --git a/server/internal/cmd/opp/internal/safetensors/safetensors.go b/server/internal/cmd/opp/internal/safetensors/safetensors.go new file mode 100644 index 000000000..7f3e99798 --- /dev/null +++ b/server/internal/cmd/opp/internal/safetensors/safetensors.go @@ -0,0 +1,220 @@ +// safetensors provides a reader for the safetensor directories and files. +package safetensors + +import ( + "encoding/json" + "fmt" + "io" + "io/fs" + "iter" + "slices" + "strconv" + "strings" +) + +// Tensor represents a single tensor in a safetensors file. +// +// It's zero value is not valid. Use [Model.Tensors] to get valid tensors. +// +// It is not safe for use across multiple goroutines. +type Tensor struct { + name string + dataType string + shape []int64 + + fsys fs.FS + fname string // entry name in fsys + offset int64 + size int64 +} + +type Model struct { + fsys fs.FS +} + +func Read(fsys fs.FS) (*Model, error) { + return &Model{fsys: fsys}, nil +} + +func (m *Model) Tensors() iter.Seq2[*Tensor, error] { + return func(yield func(*Tensor, error) bool) { + entries, err := fs.Glob(m.fsys, "*.safetensors") + if err != nil { + yield(nil, err) + return + } + for _, e := range entries { + tt, err := m.readTensors(e) + if err != nil { + yield(nil, err) + return + } + for _, t := range tt { + if !yield(t, nil) { + return + } + } + } + } +} + +func (m *Model) readTensors(fname string) ([]*Tensor, error) { + f, err := m.fsys.Open(fname) + if err != nil { + return nil, err + } + defer f.Close() + + finfo, err := f.Stat() + if err != nil { + return nil, err + } + + headerSize, err := readInt64(f) + if err != nil { + return nil, err + } + + data := make([]byte, headerSize) + _, err = io.ReadFull(f, data) + if err != nil { + return nil, err + } + + var raws map[string]json.RawMessage + if err := json.Unmarshal(data, &raws); err != nil { + return nil, err + } + + // TODO(bmizerany): do something with metadata? This could be another + // header read if needed. We also need to figure out if the metadata is + // present in only one .safetensors file or if each file may have their + // own and if it needs to follow each tensor. Currently, I (bmizerany) + // am only seeing them show up with one entry for file type which is + // always "pt". + + tt := make([]*Tensor, 0, len(raws)) + for name, raw := range raws { + if !strings.HasPrefix(name, "model.layer") { + continue + } + var v struct { + DataType string `json:"dtype"` + Shape []int64 `json:"shape"` + Offsets []int64 `json:"data_offsets"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err) + } + if len(v.Offsets) != 2 { + return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets) + } + + // TODO(bmizerany): after collecting, validate all offests make + // tensors contiguous? + begin, end := v.Offsets[0], v.Offsets[1] + if err := checkBeginEnd(finfo.Size(), begin, end); err != nil { + return nil, err + } + + // TODO(bmizerany): just yield.. 
don't be silly and make a slice :) + tt = append(tt, &Tensor{ + name: name, + dataType: v.DataType, + shape: v.Shape, + fsys: m.fsys, + fname: fname, + offset: begin, + size: end - begin, + }) + } + return tt, nil +} + +func checkBeginEnd(size, begin, end int64) error { + if begin < 0 { + return fmt.Errorf("begin must not be negative: %d", begin) + } + if end < 0 { + return fmt.Errorf("end must not be negative: %d", end) + } + if end < begin { + return fmt.Errorf("end must be >= begin: %d < %d", end, begin) + } + if end > size { + return fmt.Errorf("end must be <= size: %d > %d", end, size) + } + return nil +} + +func readInt64(r io.Reader) (int64, error) { + var v uint64 + var buf [8]byte + if _, err := io.ReadFull(r, buf[:]); err != nil { + return 0, err + } + for i := range buf { + v |= uint64(buf[i]) << (8 * i) + } + return int64(v), nil +} + +type Shape []int64 + +func (s Shape) String() string { + var b strings.Builder + b.WriteByte('[') + for i, v := range s { + if i > 0 { + b.WriteByte(',') + } + b.WriteString(strconv.FormatInt(v, 10)) + } + b.WriteByte(']') + return b.String() +} + +func (t *Tensor) Name() string { return t.name } +func (t *Tensor) DataType() string { return t.dataType } +func (t *Tensor) Size() int64 { return t.size } +func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) } + +func (t *Tensor) Reader() (io.ReadCloser, error) { + f, err := t.fsys.Open(t.fname) + if err != nil { + return nil, err + } + r := newSectionReader(f, t.offset, t.size) + rc := struct { + io.Reader + io.Closer + }{r, f} + return rc, nil +} + +// newSectionReader returns a new io.Reader that reads from r starting at +// offset. It is a convenience function for creating a io.SectionReader when r +// may not be an io.ReaderAt. +// +// If r is already a ReaderAt, it is returned directly, otherwise if r is an +// io.Seeker, a new io.ReaderAt is returned that wraps r after seeking to the +// beginning of the file. +// +// If r is an io.Seeker, +// or slow path. The slow path is used when r does not implement io.ReaderAt, +// in which case it must discard the data it reads. +func newSectionReader(r io.Reader, offset, n int64) io.Reader { + if r, ok := r.(io.ReaderAt); ok { + return io.NewSectionReader(r, offset, n) + } + if r, ok := r.(io.ReadSeeker); ok { + r.Seek(offset, io.SeekStart) + return io.LimitReader(r, n) + } + // Discard to offset and return a limited reader. + _, err := io.CopyN(io.Discard, r, offset) + if err != nil { + return nil + } + return io.LimitReader(r, n) +} diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go new file mode 100644 index 000000000..12199cf3e --- /dev/null +++ b/server/internal/cmd/opp/opp.go @@ -0,0 +1,366 @@ +package main + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log" + "mime" + "net/http" + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/ollama/ollama/server/internal/cache/blob" + "github.com/ollama/ollama/server/internal/client/ollama" + "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors" + "golang.org/x/sync/errgroup" +) + +var stdout io.Writer = os.Stdout + +const usage = `Opp is a tool for pushing and pulling Ollama models. + +Usage: + + opp [flags] + +Commands: + + push Upload a model to the Ollama server. + pull Download a model from the Ollama server. + import Import a model from a local safetensor directory. + +Examples: + + # Pull a model from the Ollama server. 
+ opp pull library/llama3.2:latest + + # Push a model to the Ollama server. + opp push username/my_model:8b + + # Import a model from a local safetensor directory. + opp import /path/to/safetensor + +Envionment Variables: + + OLLAMA_MODELS + The directory where models are pushed and pulled from + (default ~/.ollama/models). +` + +func main() { + flag.Usage = func() { + fmt.Fprint(os.Stderr, usage) + } + flag.Parse() + + c, err := ollama.DefaultCache() + if err != nil { + log.Fatal(err) + } + + rc, err := ollama.RegistryFromEnv() + if err != nil { + log.Fatal(err) + } + + ctx := context.Background() + + err = func() error { + switch cmd := flag.Arg(0); cmd { + case "pull": + return cmdPull(ctx, rc, c) + case "push": + return cmdPush(ctx, rc, c) + case "import": + return cmdImport(ctx, c) + default: + if cmd == "" { + flag.Usage() + } else { + fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd) + } + os.Exit(2) + return errors.New("unreachable") + } + }() + if err != nil { + fmt.Fprintf(os.Stderr, "opp: %v\n", err) + os.Exit(1) + } +} + +func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { + model := flag.Arg(1) + if model == "" { + flag.Usage() + os.Exit(1) + } + + tr := http.DefaultTransport.(*http.Transport).Clone() + // TODO(bmizerany): configure transport? + rc.HTTPClient = &http.Client{Transport: tr} + + var mu sync.Mutex + p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded] + + var pb bytes.Buffer + printProgress := func() { + pb.Reset() + mu.Lock() + for d, s := range p { + // Write progress to a buffer first to avoid blocking + // on stdout while holding the lock. + stamp := time.Now().Format("2006/01/02 15:04:05") + fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0])) + if s[0] == s[1] { + delete(p, d) + } + } + mu.Unlock() + io.Copy(stdout, &pb) + } + + ctx = ollama.WithTrace(ctx, &ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + if err != nil && !errors.Is(err, ollama.ErrCached) { + fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err) + return + } + + mu.Lock() + p[l.Digest] = [2]int64{l.Size, n} + mu.Unlock() + }, + }) + + errc := make(chan error) + go func() { + errc <- rc.Pull(ctx, c, model) + }() + + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-t.C: + printProgress() + case err := <-errc: + printProgress() + return err + } + } +} + +func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { + args := flag.Args()[1:] + flag := flag.NewFlagSet("push", flag.ExitOnError) + flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: opp push \n") + flag.PrintDefaults() + } + flag.Parse(args) + + model := flag.Arg(0) + if model == "" { + return fmt.Errorf("missing model argument") + } + + from := cmp.Or(*flagFrom, model) + m, err := ollama.ResolveLocal(c, from) + if err != nil { + return err + } + + ctx = ollama.WithTrace(ctx, &ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + switch { + case errors.Is(err, ollama.ErrCached): + fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n) + case err != nil: + fmt.Fprintf(stdout, "opp: uploading %s %d ! 
%v\n", l.Digest.Short(), n, err) + case n == 0: + l := m.Layer(l.Digest) + mt, p, _ := mime.ParseMediaType(l.MediaType) + mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.") + switch mt { + case "tensor": + fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"]) + default: + fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType) + } + } + }, + }) + + return rc.Push(ctx, c, model, &ollama.PushParams{ + From: from, + }) +} + +type trackingReader struct { + io.Reader + n *atomic.Int64 +} + +func (r *trackingReader) Read(p []byte) (n int, err error) { + n, err = r.Reader.Read(p) + r.n.Add(int64(n)) + return n, err +} + +func cmdImport(ctx context.Context, c *blob.DiskCache) error { + args := flag.Args()[1:] + flag := flag.NewFlagSet("import", flag.ExitOnError) + flagAs := flag.String("as", "", "Import using the provided name.") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: opp import \n") + flag.PrintDefaults() + } + flag.Parse(args) + + dir := cmp.Or(flag.Arg(0), ".") + fmt.Fprintf(os.Stderr, "Reading %s\n", dir) + + m, err := safetensors.Read(os.DirFS(dir)) + if err != nil { + return err + } + + var total int64 + var tt []*safetensors.Tensor + for t, err := range m.Tensors() { + if err != nil { + return err + } + tt = append(tt, t) + total += t.Size() + } + + var n atomic.Int64 + done := make(chan error) + go func() { + layers := make([]*ollama.Layer, len(tt)) + var g errgroup.Group + g.SetLimit(runtime.GOMAXPROCS(0)) + var ctxErr error + for i, t := range tt { + if ctx.Err() != nil { + // The context may cancel AFTER we exit the + // loop, and so if we use ctx.Err() after the + // loop we may report it as the error that + // broke the loop, when it was not. This can + // manifest as a false-negative, leading the + // user to think their import failed when it + // did not, so capture it if and only if we + // exit the loop because of a ctx.Err() and + // report it. 
+ ctxErr = ctx.Err() + break + } + g.Go(func() (err error) { + rc, err := t.Reader() + if err != nil { + return err + } + defer rc.Close() + tr := &trackingReader{rc, &n} + d, err := c.Import(tr, t.Size()) + if err != nil { + return err + } + if err := rc.Close(); err != nil { + return err + } + + layers[i] = &ollama.Layer{ + Digest: d, + Size: t.Size(), + MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{ + "name": t.Name(), + "dtype": t.DataType(), + "shape": t.Shape().String(), + }), + } + + return nil + }) + } + + done <- func() error { + if err := errors.Join(g.Wait(), ctxErr); err != nil { + return err + } + m := &ollama.Manifest{Layers: layers} + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err + } + d := blob.DigestFromBytes(data) + err = blob.PutBytes(c, d, data) + if err != nil { + return err + } + return c.Link(*flagAs, d) + }() + }() + + fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir) + + csiHideCursor(stdout) + defer csiShowCursor(stdout) + + csiSavePos(stdout) + writeProgress := func() { + csiRestorePos(stdout) + nn := n.Load() + fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n", + formatNatural(nn), + formatNatural(total), + nn*100/total, + ansiClearToEndOfLine, + ) + } + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + writeProgress() + case err := <-done: + writeProgress() + return err + } + } +} + +func formatNatural(n int64) string { + switch { + case n < 1024: + return fmt.Sprintf("%d B", n) + case n < 1024*1024: + return fmt.Sprintf("%.1f KB", float64(n)/1024) + case n < 1024*1024*1024: + return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024)) + default: + return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024)) + } +} + +const ansiClearToEndOfLine = "\033[K" + +func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") } +func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") } +func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") } +func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") } diff --git a/server/internal/cmd/oppbench/oppbench.go b/server/internal/cmd/oppbench/oppbench.go new file mode 100644 index 000000000..7a5305947 --- /dev/null +++ b/server/internal/cmd/oppbench/oppbench.go @@ -0,0 +1,11 @@ +package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("Run as 'go test -bench=.' 
to run the benchmarks") + os.Exit(1) +} diff --git a/server/internal/cmd/oppbench/oppbench_test.go b/server/internal/cmd/oppbench/oppbench_test.go new file mode 100644 index 000000000..c71d6cded --- /dev/null +++ b/server/internal/cmd/oppbench/oppbench_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "sync/atomic" + "testing" + "time" + + "github.com/ollama/ollama/server/internal/chunks" + "golang.org/x/sync/errgroup" +) + +func BenchmarkDownload(b *testing.B) { + run := func(fileSize, chunkSize int64) { + name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize) + b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) }) + } + + run(100<<20, 8<<20) + run(100<<20, 16<<20) + run(100<<20, 32<<20) + run(100<<20, 64<<20) + run(100<<20, 128<<20) // 1 chunk +} + +func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error { + const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d" + req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil) + if err != nil { + return err + } + req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) + res, err := c.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + _, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read + return err +} + +var sleepTime atomic.Int64 + +func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) { + client := &http.Client{ + Transport: func() http.RoundTripper { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.DisableKeepAlives = true + return tr + }(), + } + defer client.CloseIdleConnections() + + // warm up the client + run(context.Background(), client, chunks.New(0, 1<<20)) + + b.SetBytes(fileSize) + b.ReportAllocs() + + // Give our CDN a min to breathe between benchmarks. + time.Sleep(time.Duration(sleepTime.Swap(3))) + + for b.Loop() { + g, ctx := errgroup.WithContext(b.Context()) + g.SetLimit(runtime.GOMAXPROCS(0)) + for chunk := range chunks.Of(fileSize, chunkSize) { + g.Go(func() error { return run(ctx, client, chunk) }) + } + if err := g.Wait(); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWrite(b *testing.B) { + b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) }) +} + +func benchmarkWrite(b *testing.B, chunkSize int) { + b.ReportAllocs() + + dir := b.TempDir() + f, err := os.Create(filepath.Join(dir, "write-single")) + if err != nil { + b.Fatal(err) + } + defer f.Close() + + data := make([]byte, chunkSize) + b.SetBytes(int64(chunkSize)) + r := bytes.NewReader(data) + for b.Loop() { + r.Reset(data) + _, err := io.Copy(f, r) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go new file mode 100644 index 000000000..1f0634f7c --- /dev/null +++ b/server/internal/internal/backoff/backoff.go @@ -0,0 +1,48 @@ +package backoff + +import ( + "context" + "iter" + "math/rand/v2" + "time" +) + +func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { + var n int + return func(yield func(int, error) bool) { + var t *time.Timer + for { + if ctx.Err() != nil { + yield(n, ctx.Err()) + return + } + + if !yield(n, nil) { + return + } + + n++ + + // n^2 backoff timer is a little smoother than the + // common choice of 2^n. 
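+ // For example, attempts 1, 2, and 3 wait roughly 10ms, 40ms, and 90ms
+ // (n*n * 10ms) before jitter, capped at maxBackoff.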
+ d := time.Duration(n*n) * 10 * time.Millisecond + if d > maxBackoff { + d = maxBackoff + } + // Randomize the delay between 0.5-1.5 x msec, in order + // to prevent accidental "thundering herd" problems. + d = time.Duration(float64(d) * (rand.Float64() + 0.5)) + + if t == nil { + t = time.NewTimer(d) + } else { + t.Reset(d) + } + select { + case <-ctx.Done(): + t.Stop() + case <-t.C: + } + } + } +} diff --git a/server/internal/internal/backoff/backoff_synctest_test.go b/server/internal/internal/backoff/backoff_synctest_test.go new file mode 100644 index 000000000..cf17ce80a --- /dev/null +++ b/server/internal/internal/backoff/backoff_synctest_test.go @@ -0,0 +1,40 @@ +//go:build goexperiment.synctest + +package backoff + +import ( + "context" + "errors" + "testing" + "testing/synctest" + "time" +) + +func TestLoop(t *testing.T) { + synctest.Run(func() { + last := -1 + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + for n, err := range Loop(ctx, 100*time.Millisecond) { + if !errors.Is(err, ctx.Err()) { + t.Errorf("err = %v, want nil", err) + } + if err != nil { + break + } + if n != last+1 { + t.Errorf("n = %d, want %d", n, last+1) + } + last = n + if n > 5 { + cancel() + } + } + + if last != 6 { + t.Errorf("last = %d, want 6", last) + } + }) +} diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go new file mode 100644 index 000000000..bb8438a78 --- /dev/null +++ b/server/internal/internal/backoff/backoff_test.go @@ -0,0 +1,38 @@ +package backoff + +import ( + "context" + "testing" + "testing/synctest" + "time" +) + +func TestLoopAllocs(t *testing.T) { + for i := range 3 { + got := testing.AllocsPerRun(1000, func() { + for tick := range Loop(t.Context(), 1) { + if tick >= i { + break + } + } + }) + want := float64(0) + if i > 0 { + want = 3 // due to time.NewTimer + } + if got > want { + t.Errorf("[%d ticks]: allocs = %v, want 0", i, want) + } + } +} + +func BenchmarkLoop(b *testing.B) { + ctx := context.Background() + synctest.Run(func() { + for n := range Loop(ctx, 100*time.Millisecond) { + if n == b.N { + break + } + } + }) +} diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go new file mode 100644 index 000000000..361cce76f --- /dev/null +++ b/server/internal/internal/names/name.go @@ -0,0 +1,229 @@ +package names + +import ( + "cmp" + "fmt" + "strings" + + "github.com/ollama/ollama/server/internal/internal/stringsx" +) + +const MaxNameLength = 50 + 1 + 50 + 1 + 50 // /: + +type Name struct { + // Make incomparable to enfoce use of Compare / Equal for + // case-insensitive comparisons. + _ [0]func() + + h string + n string + m string + t string +} + +// Parse parses and assembles a Name from a name string. The +// format of a valid name string is: +// +// s: +// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } +// { host } "/" { namespace } "/" { model } ":" { tag } +// { host } "/" { namespace } "/" { model } "@" { digest } +// { host } "/" { namespace } "/" { model } +// { namespace } "/" { model } ":" { tag } "@" { digest } +// { namespace } "/" { model } ":" { tag } +// { namespace } "/" { model } "@" { digest } +// { namespace } "/" { model } +// { model } ":" { tag } "@" { digest } +// { model } ":" { tag } +// { model } "@" { digest } +// { model } +// "@" { digest } +// host: +// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." 
| ":" }* +// length: [1, 350] +// namespace: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" }* +// length: [1, 80] +// model: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* +// length: [1, 80] +// tag: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* +// length: [1, 80] +// digest: +// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* +// length: [1, 80] +// +// The name returned is not guaranteed to be valid. If it is not valid, the +// field values are left in an undefined state. Use [Name.IsValid] to check +// if the name is valid. +func Parse(s string) Name { + if len(s) > MaxNameLength { + return Name{} + } + + var n Name + var tail string + var c byte + for { + s, tail, c = cutLastAny(s, "/:") + switch c { + case ':': + n.t = tail + continue // look for model + case '/': + n.h, n.n, _ = cutLastAny(s, "/") + n.m = tail + return n + case 0: + n.m = tail + return n + } + } +} + +// ParseExtended parses and returns any scheme, Name, and digest from from s in +// the the form [scheme://][name][@digest]. All parts are optional. +// +// If the scheme is present, it must be followed by "://". The digest is +// prefixed by "@" and comes after the name. The name is parsed using [Parse]. +// +// The scheme and digest are stripped before the name is parsed by [Parse]. +// +// For convience, the scheme is never empty. If the scheme is not present, the +// returned scheme is "https". +// +// Examples: +// +// http://ollama.com/bmizerany/smol:latest@digest +// https://ollama.com/bmizerany/smol:latest +// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. +func ParseExtended(s string) (scheme string, _ Name, digest string) { + i := strings.Index(s, "://") + if i >= 0 { + scheme = s[:i] + s = s[i+3:] + } + i = strings.LastIndex(s, "@") + if i >= 0 { + digest = s[i+1:] + s = s[:i] + } + return scheme, Parse(s), digest +} + +func FormatExtended(scheme string, n Name, digest string) string { + var b strings.Builder + if scheme != "" { + b.WriteString(scheme) + b.WriteString("://") + } + b.WriteString(n.String()) + if digest != "" { + b.WriteByte('@') + b.WriteString(digest) + } + return b.String() +} + +// Merge merges two names into a single name. Non-empty host, namespace, and +// tag parts of a take precedence over fields in b. The model field is left as +// is. +// +// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check +// if the name is valid. +func Merge(a, b Name) Name { + a.h = cmp.Or(a.h, b.h) + a.n = cmp.Or(a.n, b.n) + a.t = cmp.Or(a.t, b.t) + return a +} + +// IsValid returns true if the name is valid. 
+func (n Name) IsValid() bool { + if n.h != "" && !isValidHost(n.h) { + return false + } + if n.n != "" && !isValidNamespace(n.n) { + return false + } + if n.m != "" && !isValidModel(n.m) { + return false + } + if n.t != "" && !isValidTag(n.t) { + return false + } + return true +} + +func (n Name) IsFullyQualified() bool { + return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != "" +} + +func isValidHost(_ string) bool { + return true // TODO: implement +} + +func isValidNamespace(_ string) bool { + return true // TODO: implement +} + +func isValidModel(_ string) bool { + return true // TODO: implement +} + +func isValidTag(_ string) bool { + return true // TODO: implement +} + +func (n Name) Host() string { return n.h } +func (n Name) Namespace() string { return n.n } +func (n Name) Model() string { return n.m } +func (n Name) Tag() string { return n.t } + +// Compare compares n and o case-insensitively. It returns 0 if n and o are +// equal, -1 if n sorts before o, and 1 if n sorts after o. +func (n Name) Compare(o Name) int { + return cmp.Or( + stringsx.CompareFold(n.h, o.h), + stringsx.CompareFold(n.n, o.n), + stringsx.CompareFold(n.m, o.m), + stringsx.CompareFold(n.t, o.t), + ) +} + +// String returns the fully qualified name in the format +// /:. +func (n Name) String() string { + var b strings.Builder + if n.h != "" { + b.WriteString(n.h) + b.WriteByte('/') + } + if n.n != "" { + b.WriteString(n.n) + b.WriteByte('/') + } + b.WriteString(n.m) + if n.t != "" { + b.WriteByte(':') + b.WriteString(n.t) + } + return b.String() +} + +func (n Name) GoString() string { + return fmt.Sprintf("", n.h, n.n, n.m, n.t) +} + +// cutLastAny is like strings.Cut but scans in reverse for the last character +// in chars. If no character is found, before is the empty string and after is +// s. The returned sep is the byte value of the character in chars if one was +// found; otherwise it is 0. +func cutLastAny(s, chars string) (before, after string, sep byte) { + i := strings.LastIndexAny(s, chars) + if i >= 0 { + return s[:i], s[i+1:], s[i] + } + return "", s, 0 +} diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go new file mode 100644 index 000000000..760fec5fa --- /dev/null +++ b/server/internal/internal/names/name_test.go @@ -0,0 +1,152 @@ +package names + +import ( + "strings" + "testing" +) + +func TestParseName(t *testing.T) { + cases := []struct { + in string + want Name + }{ + {"", Name{}}, + {"m:t", Name{m: "m", t: "t"}}, + {"m", Name{m: "m"}}, + {"/m", Name{m: "m"}}, + {"/n/m:t", Name{n: "n", m: "m", t: "t"}}, + {"n/m", Name{n: "n", m: "m"}}, + {"n/m:t", Name{n: "n", m: "m", t: "t"}}, + {"n/m", Name{n: "n", m: "m"}}, + {"n/m", Name{n: "n", m: "m"}}, + {strings.Repeat("m", MaxNameLength+1), Name{}}, + {"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}}, + {"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}}, + + // Invalids + // TODO: {"n:t/m:t", Name{}}, + // TODO: {"/h/n/m:t", Name{}}, + } + + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + got := Parse(tt.in) + if got.Compare(tt.want) != 0 { + t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestString(t *testing.T) { + cases := []string{ + "", + "m:t", + "m:t", + "m", + "n/m", + "n/m:t", + "n/m", + "n/m", + "h/n/m:t", + "ollama.com/library/_:latest", + + // Special cased to "round trip" without the leading slash. 
+ "/m", + "/n/m:t", + } + for _, s := range cases { + t.Run(s, func(t *testing.T) { + s = strings.TrimPrefix(s, "/") + if g := Parse(s).String(); g != s { + t.Errorf("parse(%q).String() = %q", s, g) + } + }) + } +} + +func TestParseExtended(t *testing.T) { + cases := []struct { + in string + + wantScheme string + wantName Name + wantDigest string + }{ + {"", "", Name{}, ""}, + {"m", "", Name{m: "m"}, ""}, + {"http://m", "http", Name{m: "m"}, ""}, + {"http+insecure://m", "http+insecure", Name{m: "m"}, ""}, + {"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + scheme, name, digest := ParseExtended(tt.in) + if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest { + t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest) + } + + // Round trip + if got := FormatExtended(scheme, name, digest); got != tt.in { + t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got) + } + }) + } +} + +func TestMerge(t *testing.T) { + cases := []struct { + a, b string + want string + }{ + {"", "", ""}, + {"m", "", "m"}, + {"", "m", ""}, + {"x", "y", "x"}, + {"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"}, + {"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"}, + + {"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"}, + {"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"}, + } + for _, tt := range cases { + t.Run("", func(t *testing.T) { + a, b := Parse(tt.a), Parse(tt.b) + got := Merge(a, b) + if got.Compare(Parse(tt.want)) != 0 { + t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestParseStringRoundTrip(t *testing.T) { + cases := []string{ + "", + "m", + "m:t", + "n/m", + "n/m:t", + "n/m:t", + "n/m", + "n/m", + "h/n/m:t", + "ollama.com/library/_:latest", + } + for _, s := range cases { + t.Run(s, func(t *testing.T) { + if got := Parse(s).String(); got != s { + t.Errorf("parse(%q).String() = %q", s, got) + } + }) + } +} + +var junkName Name + +func BenchmarkParseName(b *testing.B) { + b.ReportAllocs() + for range b.N { + junkName = Parse("h/n/m:t") + } +} diff --git a/server/internal/internal/stringsx/stringsx.go b/server/internal/internal/stringsx/stringsx.go new file mode 100644 index 000000000..6c7a8d20d --- /dev/null +++ b/server/internal/internal/stringsx/stringsx.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package stringsx provides additional string manipulation functions +// that aren't in the standard library's strings package or go4.org/mem. +package stringsx + +import ( + "unicode" + "unicode/utf8" +) + +// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b, +// like cmp.Compare, but case insensitively. 
+func CompareFold(a, b string) int { + // Track our position in both strings + ia, ib := 0, 0 + for ia < len(a) && ib < len(b) { + ra, wa := nextRuneLower(a[ia:]) + rb, wb := nextRuneLower(b[ib:]) + if ra < rb { + return -1 + } + if ra > rb { + return 1 + } + ia += wa + ib += wb + if wa == 0 || wb == 0 { + break + } + } + + // If we've reached here, one or both strings are exhausted + // The shorter string is "less than" if they match up to this point + switch { + case ia == len(a) && ib == len(b): + return 0 + case ia == len(a): + return -1 + default: + return 1 + } +} + +// nextRuneLower returns the next rune in the string, lowercased, along with its +// original (consumed) width in bytes. If the string is empty, it returns +// (utf8.RuneError, 0) +func nextRuneLower(s string) (r rune, width int) { + r, width = utf8.DecodeRuneInString(s) + return unicode.ToLower(r), width +} diff --git a/server/internal/internal/stringsx/stringsx_test.go b/server/internal/internal/stringsx/stringsx_test.go new file mode 100644 index 000000000..8575c0b27 --- /dev/null +++ b/server/internal/internal/stringsx/stringsx_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package stringsx + +import ( + "cmp" + "strings" + "testing" +) + +func TestCompareFold(t *testing.T) { + tests := []struct { + a, b string + }{ + // Basic ASCII cases + {"", ""}, + {"a", "a"}, + {"a", "A"}, + {"A", "a"}, + {"a", "b"}, + {"b", "a"}, + {"abc", "ABC"}, + {"ABC", "abc"}, + {"abc", "abd"}, + {"abd", "abc"}, + + // Length differences + {"abc", "ab"}, + {"ab", "abc"}, + + // Unicode cases + {"世界", "世界"}, + {"Hello世界", "hello世界"}, + {"世界Hello", "世界hello"}, + {"世界", "世界x"}, + {"世界x", "世界"}, + + // Special case folding examples + {"ß", "ss"}, // German sharp s + {"fi", "fi"}, // fi ligature + {"Σ", "σ"}, // Greek sigma + {"İ", "i\u0307"}, // Turkish dotted I + + // Mixed cases + {"HelloWorld", "helloworld"}, + {"HELLOWORLD", "helloworld"}, + {"helloworld", "HELLOWORLD"}, + {"HelloWorld", "helloworld"}, + {"helloworld", "HelloWorld"}, + + // Edge cases + {" ", " "}, + {"1", "1"}, + {"123", "123"}, + {"!@#", "!@#"}, + } + + wants := []int{} + for _, tt := range tests { + got := CompareFold(tt.a, tt.b) + want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b)) + if got != want { + t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want) + } + wants = append(wants, want) + } + + if n := testing.AllocsPerRun(1000, func() { + for i, tt := range tests { + if CompareFold(tt.a, tt.b) != wants[i] { + panic("unexpected") + } + } + }); n > 0 { + t.Errorf("allocs = %v; want 0", int(n)) + } +} diff --git a/server/internal/internal/syncs/line.go b/server/internal/internal/syncs/line.go new file mode 100644 index 000000000..021cd4c09 --- /dev/null +++ b/server/internal/internal/syncs/line.go @@ -0,0 +1,201 @@ +// Package syncs provides synchronization primitives. +package syncs + +import ( + "cmp" + "io" + "sync" +) + +var closedChan = func() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +}() + +// Ticket represents a ticket in a sequence of tickets. The zero value is +// invalid. Use [Line.Take] to get a valid ticket. +// +// A Ticket is not safe for concurrent use. +type Ticket struct { + ahead chan struct{} // ticket ahead of this one + ch chan struct{} +} + +// Ready returns a channel that is closed when the ticket before this one is +// done. +// +// It is incorrect to wait on Ready after the ticket is done. 
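+//
+// A typical ordered-work pattern looks like (sketch):
+//
+//	t := line.Take()
+//	<-t.Ready() // wait for the ticket ahead of us to call Done
+//	// ... do this ticket's work ...
+//	t.Done()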
+func (t *Ticket) Ready() chan struct{} {
+	return cmp.Or(t.ahead, closedChan)
+}
+
+// Done signals that this ticket is done and that the next ticket in line can
+// proceed.
+//
+// The first call to [Done] unblocks the ticket after it, if any. Subsequent
+// calls are no-ops.
+func (t *Ticket) Done() {
+	if t.ch != nil {
+		close(t.ch)
+	}
+	t.ch = nil
+}
+
+// Line is an ordered sequence of tickets waiting for their turn to proceed.
+//
+// To get a ticket use [Line.Take].
+// To signal that a ticket is done use [Ticket.Done].
+// To wait your turn use [Ticket.Ready].
+//
+// A Line is not safe for concurrent use.
+type Line struct {
+	last chan struct{} // last ticket in line
+}
+
+// Take returns the next ticket in line. The ticket becomes ready when the
+// ticket taken before it is done.
+func (q *Line) Take() *Ticket {
+	t := &Ticket{
+		ahead: q.last,
+		ch:    make(chan struct{}),
+	}
+	q.last = t.ch
+	return t
+}
+
+// RelayReader implements an [io.WriterTo] that relays the writer passed to its
+// [WriteTo] method to each [io.WriteCloser] taken from [Take], in the order
+// they are taken. Each [io.WriteCloser] blocks until the previous one is
+// closed, or until a call to [RelayReader.CloseWithError] is made.
+//
+// The zero value is invalid. Use [NewRelayReader] to get a valid RelayReader.
+//
+// Except where otherwise noted, it is not safe for concurrent use.
+type RelayReader struct {
+	line Line
+	t    *Ticket
+	w    io.Writer
+	n    int64
+
+	mu       sync.Mutex
+	err      error         // set by CloseWithError
+	closedCh chan struct{} // closed if err is set
+}
+
+var (
+	_ io.Closer   = (*RelayReader)(nil)
+	_ io.WriterTo = (*RelayReader)(nil)
+	_ io.Reader   = (*RelayReader)(nil)
+)
+
+// NewRelayReader returns a RelayReader ready for use.
+func NewRelayReader() *RelayReader {
+	var q RelayReader
+	q.closedCh = make(chan struct{})
+	q.t = q.line.Take()
+	return &q
+}
+
+// CloseWithError terminates the line, unblocking any writer waiting for its
+// turn with the error, or [io.EOF] if err is nil. It is safe to call
+// [CloseWithError] multiple times and across multiple goroutines.
+//
+// If the line is already closed, [CloseWithError] is a no-op.
+//
+// It never returns an error.
+func (q *RelayReader) CloseWithError(err error) error {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	if q.err == nil {
+		q.err = cmp.Or(q.err, err, io.EOF)
+		close(q.closedCh)
+	}
+	return nil
+}
+
+// Close closes the line. Any writer waiting for its turn is unblocked, and its
+// writes report [io.EOF].
+//
+// It never returns an error.
+func (q *RelayReader) Close() error {
+	return q.CloseWithError(nil)
+}
+
+func (q *RelayReader) closed() <-chan struct{} {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	return q.closedCh
+}
+
+// Read exists only so that a RelayReader satisfies [io.Reader] (for example,
+// as the src of [io.Copy], which uses WriteTo instead); it must not be called
+// directly.
+func (q *RelayReader) Read(p []byte) (int, error) {
+	panic("RelayReader.Read is for show only; use WriteTo")
+}
+
+// WriteTo yields dst to the first writer in line and blocks until the line is
+// closed by [RelayReader.Close] or [RelayReader.CloseWithError]. It returns
+// the total number of bytes written to dst.
+//
+// It is safe to call [Take] concurrently with [WriteTo].
+func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
+	select {
+	case <-q.closed():
+		return 0, io.ErrClosedPipe
+	default:
+	}
+
+	// We have a destination writer; let the relay begin.
+	q.w = dst
+	q.t.Done()
+	<-q.closed()
+	return q.n, nil
+}
+
+// Take returns a writer that relays its writes to the destination passed to
+// [RelayReader.WriteTo] once all previously taken writers have been closed.
+//
+// It is not safe for use across multiple goroutines.
+func (q *RelayReader) Take() io.WriteCloser {
+	return &relayWriter{q: q, t: q.line.Take()}
+}
+
+type relayWriter struct {
+	q     *RelayReader
+	t     *Ticket
+	ready bool
+}
+
+var _ io.StringWriter = (*relayWriter)(nil)
+
+// Write writes to the writer passed to [RelayReader.WriteTo] as soon as the
+// writer is ready.
+// If the line is closed before the writer is ready, Write returns the error
+// the line was closed with ([io.EOF] if it was closed without an error).
+func (w *relayWriter) Write(p []byte) (int, error) {
+	if !w.awaitTurn() {
+		return 0, w.q.err
+	}
+	n, err := w.q.w.Write(p)
+	w.q.n += int64(n)
+	return n, err
+}
+
+// WriteString is like Write, but writes a string. The bytes written are
+// counted toward the total reported by [RelayReader.WriteTo].
+func (w *relayWriter) WriteString(s string) (int, error) {
+	if !w.awaitTurn() {
+		return 0, w.q.err
+	}
+	n, err := io.WriteString(w.q.w, s)
+	w.q.n += int64(n)
+	return n, err
+}
+
+// Close signals that the writer is done, unblocking the next writer in line.
+func (w *relayWriter) Close() error {
+	w.t.Done()
+	return nil
+}
+
+// awaitTurn blocks until it is this writer's turn, reporting false if the line
+// is closed first.
+func (w *relayWriter) awaitTurn() (ok bool) {
+	if w.ready {
+		return true
+	}
+	select {
+	case <-w.t.Ready():
+		w.ready = true
+		return true
+	case <-w.q.closed():
+		return false
+	}
+}
diff --git a/server/internal/internal/syncs/line_test.go b/server/internal/internal/syncs/line_test.go
new file mode 100644
index 000000000..d52160260
--- /dev/null
+++ b/server/internal/internal/syncs/line_test.go
@@ -0,0 +1,65 @@
+package syncs
+
+import (
+	"bytes"
+	"io"
+	"math/rand/v2"
+	"testing"
+	"testing/synctest"
+)
+
+func TestPipelineReadWriterTo(t *testing.T) {
+	for range 10 {
+		synctest.Run(func() {
+			q := NewRelayReader()
+
+			tickets := []struct {
+				io.WriteCloser
+				s string
+			}{
+				{q.Take(), "you"},
+				{q.Take(), " say hi,"},
+				{q.Take(), " and "},
+				{q.Take(), "I say "},
+				{q.Take(), "hello"},
+			}
+
+			// Shuffle the launch order; the relay must still produce the
+			// output in the order the writers were taken.
+			rand.Shuffle(len(tickets), func(i, j int) {
+				tickets[i], tickets[j] = tickets[j], tickets[i]
+			})
+
+			var g Group
+			for i, t := range tickets {
+				g.Go(func() {
+					defer t.Close()
+					if i%2 == 0 {
+						// Use [relayWriter.WriteString]
+						io.WriteString(t.WriteCloser, t.s)
+					} else {
+						t.Write([]byte(t.s))
+					}
+				})
+			}
+
+			var got bytes.Buffer
+			var copyErr error // checked at end
+			g.Go(func() {
+				_, copyErr = io.Copy(&got, q)
+			})
+
+			synctest.Wait()
+
+			q.Close()
+			g.Wait()
+
+			if copyErr != nil {
+				t.Fatal(copyErr)
+			}
+
+			want := "you say hi, and I say hello"
+			if got.String() != want {
+				t.Fatalf("got %q, want %q", got.String(), want)
+			}
+		})
+	}
+}
diff --git a/server/internal/internal/syncs/syncs.go b/server/internal/internal/syncs/syncs.go
new file mode 100644
index 000000000..8f1b1e078
--- /dev/null
+++ b/server/internal/internal/syncs/syncs.go
@@ -0,0 +1,41 @@
+package syncs
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+// Group is a [sync.WaitGroup] with a Go method.
+type Group struct {
+	wg sync.WaitGroup
+	n  atomic.Int64
+}
+
+// Go runs f in a new goroutine tracked by the group.
+func (g *Group) Go(f func()) {
+	g.wg.Add(1)
+	go func() {
+		g.n.Add(1) // Now we are running
+		defer func() {
+			g.wg.Done()
+			g.n.Add(-1) // Now we are done
+		}()
+		f()
+	}()
+}
+
+// Running returns the number of goroutines that are currently running.
+//
+// If a call to [Running] returns zero, and a call to [Wait] is made without
+// any calls to [Go], then [Wait] will return immediately. This is true even if
+// a goroutine is started and finishes between the two calls.
+//
+// It is possible for [Running] to return non-zero and for [Wait] to return
+// immediately. This can happen if all the running goroutines finish between
+// the two calls.
+func (g *Group) Running() int64 {
+	return g.n.Load()
+}
+
+// Wait blocks until all goroutines started with [Go] have finished.
+func (g *Group) Wait() {
+	g.wg.Wait()
+}
diff --git a/server/internal/internal/testutil/testutil.go b/server/internal/internal/testutil/testutil.go
new file mode 100644
index 000000000..354c2608c
--- /dev/null
+++ b/server/internal/internal/testutil/testutil.go
@@ -0,0 +1,74 @@
+package testutil
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+// Check calls t.Fatal(err) if err is not nil.
+func Check(t *testing.T, err error) {
+	if err != nil {
+		t.Helper()
+		t.Fatal(err)
+	}
+}
+
+// CheckFunc exists so other packages do not need to invent their own type for
+// taking a Check function.
+type CheckFunc func(err error)
+
+// Checker returns a check function that calls t.Fatal if err is not nil.
+func Checker(t *testing.T) (check func(err error)) {
+	return func(err error) {
+		if err != nil {
+			t.Helper()
+			t.Fatal(err)
+		}
+	}
+}
+
+// StopPanic runs f but silently recovers from any panic f causes.
+// The normal usage is:
+//
+//	testutil.StopPanic(func() {
+//		callThatShouldPanic()
+//		t.Errorf("callThatShouldPanic did not panic")
+//	})
+func StopPanic(f func()) {
+	defer func() { recover() }()
+	f()
+}
+
+// CheckTime calls t.Fatalf if got != want. Included in the error message is
+// want.Sub(got) to help diagnose the difference, along with their values in
+// UTC.
+func CheckTime(t *testing.T, got, want time.Time) {
+	t.Helper()
+	if !got.Equal(want) {
+		t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
+	}
+}
+
+// WriteFile writes data to a file named name, creating any parent directories
+// as needed. Directories are created with mode 0o755 and the file is written
+// with mode 0o644.
+//
+// The name must be a relative path; if it is absolute, WriteFile fails the
+// test.
+func WriteFile[S []byte | string](t testing.TB, name string, data S) {
+	t.Helper()
+
+	if filepath.IsAbs(name) {
+		t.Fatalf("WriteFile: name must be a relative path, got %q", name)
+	}
+	name = filepath.Clean(name)
+	dir := filepath.Dir(name)
+	if err := os.MkdirAll(dir, 0o755); err != nil {
+		t.Fatal(err)
+	}
+	if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/server/internal/manifest/manifest.go b/server/internal/manifest/manifest.go
new file mode 100644
index 000000000..e020d2c02
--- /dev/null
+++ b/server/internal/manifest/manifest.go
@@ -0,0 +1,118 @@
+// Package manifest provides documentation for the Ollama manifest format.
+// This package contains no code.
+//
+// # Manifests
+//
+// A manifest is a JSON object that describes a model. The JSON object has a
+// single field, "layers", which is a list of the layers that make up the
+// model.
+//
+// A layer is a single, logical unit of a model. Layers are stored in the cache
+// as files named after the digest of the layer. Layers are pushed and pulled
+// from the registry as blobs.
+//
+// A layer is represented as a JSON object with the following fields:
+//
+// - "digest": The digest of the layer.
+// - "mediaType": The media type of the layer.
+// - "size": The size of the layer in bytes.
+//
+// Layers are typically stored in a blob store, such as a registry, and are
+// referenced by their digest. This package does not define how layers are
+// stored or retrieved.
+//
+// # Configuration Layer
+//
+// The configuration of a model is represented as a layer with the media type:
+//
+//	application/vnd.ollama.image.config; type=<type>
+//
+// The "type" parameter in the media type specifies the format of the
+// configuration (e.g., "safetensor" or "gguf").
+//
+// There may be only one configuration layer in a model.
+//
+// # Template Layer
+//
+// The model template is a layer with the media type:
+//
+//	application/vnd.ollama.image.template; [name=<name>]
+//
+// The "name" parameter in the media type specifies the name of the template,
+// used for lookup at runtime. The name is optional and may be omitted.
+// If omitted, the template is the default template for the model.
+//
+// # Tensor Layers
+//
+// The tensors of a model are represented as layers with the media type:
+//
+//	application/vnd.ollama.image.tensor; name=<name>; dtype=<dtype>; shape=<shape>
+//
+// The "name" parameter in the media type specifies the name of the tensor as
+// defined in the model's configuration. Tensor names are bound only by the
+// rules for names in the configuration format, as indicated by the
+// configuration's "type".
+//
+// The "dtype" parameter in the media type specifies the data type of the
+// tensor as a string.
+//
+// TODO: Define more specifically how to represent data types as strings.
+//
+// The "shape" parameter in the media type specifies the shape of the tensor as
+// a comma-separated list of integers, one per dimension.
+//
+// # Tokenization Layers
+//
+// The tokenization of a model is represented as a layer with the media type:
+//
+//	application/vnd.ollama.image.tokenizer
+//
+// The configuration of the tokenizer is represented as a layer with the media
+// type:
+//
+//	application/vnd.ollama.image.tokenizer.config
+//
+// # Miscellaneous Layers
+//
+// These extra layer media types are reserved:
+//
+//	application/vnd.ollama.image.license
+//
+// This layer contains one of the model's licenses in plain text; a model may
+// include multiple license layers.
+//
+// # Example Manifest
+//
+// The following is an example manifest containing a configuration, a model
+// template, and two tensors (digests shortened for brevity):
+//
+//	{
+//	  "layers": [{
+//	    "digest": "sha256:a...",
+//	    "mediaType": "application/vnd.ollama.image.config; type=safetensors",
+//	    "size": 1234
+//	  },{
+//	    "digest": "sha256:b...",
+//	    "mediaType": "application/vnd.ollama.image.template",
+//	    "size": 5678
+//	  },{
+//	    "digest": "sha256:c...",
+//	    "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3",
+//	    "size": 9012
+//	  },{
+//	    "digest": "sha256:d...",
+//	    "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6",
+//	    "size": 3456
+//	  }]
+//	}
+//
+// # Legacy Media Types
+//
+// The application/vnd.ollama.image.model media type is deprecated, but will
+// remain supported for backwards compatibility for some undefined amount of
+// time. New models should use the media types defined above.
+//
+// # Reserved Media Types
+//
+// The media type prefix "application/vnd.ollama.image." is reserved for
+// defining new media types for layers known to Ollama. Currently, all other
+// prefixes are ignored by official Ollama registry clients.
+package manifest
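To make the manifest format documented above concrete, here is a minimal, illustrative Go sketch (not part of this change) of how a client might decode such a manifest and inspect its layers. The "layer" and "manifest" struct names are assumptions made for the example only; they are not types defined by this package or elsewhere in the diff.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// layer mirrors the documented layer fields: digest, mediaType, and size.
type layer struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Size      int64  `json:"size"`
}

// manifest mirrors the documented top-level object with its single "layers" field.
type manifest struct {
	Layers []layer `json:"layers"`
}

func main() {
	data := []byte(`{
		"layers": [
			{"digest": "sha256:a...", "mediaType": "application/vnd.ollama.image.config; type=safetensors", "size": 1234},
			{"digest": "sha256:c...", "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3", "size": 9012}
		]
	}`)

	var m manifest
	if err := json.Unmarshal(data, &m); err != nil {
		panic(err)
	}
	for _, l := range m.Layers {
		// The part before the ";" identifies the layer kind; parameters such
		// as name, dtype, and shape follow it.
		kind, params, _ := strings.Cut(l.MediaType, ";")
		fmt.Printf("%-45s params=%q digest=%s size=%d\n",
			kind, strings.TrimSpace(params), l.Digest, l.Size)
	}
}

A fuller client would likely split out the media type parameters with the standard library's mime.ParseMediaType rather than a bare strings.Cut, but the shape of the data is the same.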