mirror of
https://github.com/ollama/ollama.git
synced 2025-04-06 10:58:36 +02:00
server/internal/client/ollama: hold DiskCache on Registry (#9463)
Previously, using a Registry required a DiskCache to be passed in for use in various methods. This was a bit cumbersome, as the DiskCache is required for most operations, and the DefaultCache is used in most of those cases. This change makes the DiskCache an optional field on the Registry struct. This also changes DefaultCache to initialize on first use. This is to not burden clients with the cost of creating a new cache per use, or having to hold onto a cache for the lifetime of the Registry. Also, slip in some minor docs updates for Trace.
This commit is contained in:
parent
e41c4cbea7
commit
3519dd1c6e
@ -27,6 +27,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -73,19 +74,22 @@ const (
|
||||
DefaultMaxChunkSize = 8 << 20
|
||||
)
|
||||
|
||||
// DefaultCache returns a new disk cache for storing models. If the
|
||||
// OLLAMA_MODELS environment variable is set, it uses that directory;
|
||||
// otherwise, it uses $HOME/.ollama/models.
|
||||
func DefaultCache() (*blob.DiskCache, error) {
|
||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||
dir := os.Getenv("OLLAMA_MODELS")
|
||||
if dir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
home = cmp.Or(home, ".")
|
||||
dir = filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
return blob.Open(dir)
|
||||
})
|
||||
|
||||
// DefaultCache returns the default cache used by the registry. It is
|
||||
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
||||
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
||||
// it uses the current working directory.
|
||||
func DefaultCache() (*blob.DiskCache, error) {
|
||||
return defaultCache()
|
||||
}
|
||||
|
||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||
@ -168,6 +172,10 @@ func CompleteName(name string) string {
|
||||
// Registry is a client for performing push and pull operations against an
|
||||
// Ollama registry.
|
||||
type Registry struct {
|
||||
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
||||
// used.
|
||||
Cache *blob.DiskCache
|
||||
|
||||
// UserAgent is the User-Agent header to send with requests to the
|
||||
// registry. If empty, the User-Agent is determined by HTTPClient.
|
||||
UserAgent string
|
||||
@ -206,12 +214,18 @@ type Registry struct {
|
||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||
MaxChunkSize int64
|
||||
|
||||
// Mask, if set, is the name used to convert non-fully qualified
|
||||
// names to fully qualified names. If empty, the default mask
|
||||
// ("registry.ollama.ai/library/_:latest") is used.
|
||||
// Mask, if set, is the name used to convert non-fully qualified names
|
||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||
Mask string
|
||||
}
|
||||
|
||||
func (r *Registry) cache() (*blob.DiskCache, error) {
|
||||
if r.Cache != nil {
|
||||
return r.Cache, nil
|
||||
}
|
||||
return defaultCache()
|
||||
}
|
||||
|
||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
||||
mask := defaultMask
|
||||
if r.Mask != "" {
|
||||
@ -282,12 +296,17 @@ type PushParams struct {
|
||||
}
|
||||
|
||||
// Push pushes the model with the name in the cache to the remote registry.
|
||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
||||
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||
if p == nil {
|
||||
p = &PushParams{}
|
||||
}
|
||||
|
||||
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := r.ResolveLocal(cmp.Or(p.From, name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -403,7 +422,7 @@ func canRetry(err error) bool {
|
||||
// chunks of the specified size, and then reassembled and verified. This is
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -417,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||
}
|
||||
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
exists := func(l *Layer) bool {
|
||||
info, err := c.Get(l.Digest)
|
||||
return err == nil && info.Size == l.Size
|
||||
@ -554,11 +578,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
|
||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||
// before attempting to unlink the model.
|
||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return c.Unlink(n.String())
|
||||
}
|
||||
|
||||
@ -631,12 +659,17 @@ type Layer struct {
|
||||
}
|
||||
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
||||
_, n, d, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !d.IsValid() {
|
||||
// No digest, so resolve the manifest by name.
|
||||
d, err = c.Resolve(n.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
t.Helper()
|
||||
|
||||
c, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
}
|
||||
|
||||
r := &Registry{
|
||||
Cache: c,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(h),
|
||||
},
|
||||
@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||
}
|
||||
|
||||
func TestPushZero(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), c, "empty", nil)
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "empty", nil)
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSingle(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), c, "single", nil)
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "single", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushMultiple(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), c, "multiple", nil)
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "multiple", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushNotFound(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
})
|
||||
err := rc.Push(t.Context(), c, "notfound", nil)
|
||||
err := rc.Push(t.Context(), "notfound", nil)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushNullLayer(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), c, "null", nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "null", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSizeMismatch(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
ctx, _ := withTraceUnexpected(t.Context())
|
||||
got := rc.Push(ctx, c, "sizemismatch", nil)
|
||||
got := rc.Push(ctx, "sizemismatch", nil)
|
||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||
t.Errorf("err = %v; want size mismatch", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushInvalid(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), c, "invalid", nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "invalid", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
|
||||
|
||||
func TestPushExistsAtRemote(t *testing.T) {
|
||||
var pushed bool
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||
if !pushed {
|
||||
// First push. Return an uploadURL.
|
||||
@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err := rc.Push(ctx, c, "single", nil)
|
||||
err := rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
|
||||
if !errors.Is(errors.Join(errs...), nil) {
|
||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||
}
|
||||
|
||||
err = rc.Push(ctx, c, "single", nil)
|
||||
err = rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPushRemoteError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||
return
|
||||
}
|
||||
})
|
||||
got := rc.Push(t.Context(), c, "single", nil)
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
checkErrCode(t, got, 500, "blob_error")
|
||||
}
|
||||
|
||||
func TestPushLocationError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", ":///x")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
got := rc.Push(t.Context(), c, "single", nil)
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
wantContains := "invalid upload URL"
|
||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||
@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPushUploadRoundtripError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "blob.store" {
|
||||
w.WriteHeader(499) // force RoundTrip error on upload
|
||||
return
|
||||
}
|
||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||
})
|
||||
got := rc.Push(t.Context(), c, "single", nil)
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
if !errors.Is(got, errRoundTrip) {
|
||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||
}
|
||||
@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
|
||||
os.Remove(c.GetFile(l.Digest))
|
||||
},
|
||||
})
|
||||
got := rc.Push(ctx, c, "single", nil)
|
||||
got := rc.Push(ctx, "single", nil)
|
||||
if !errors.Is(got, fs.ErrNotExist) {
|
||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushCommitRoundtripError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
panic("unexpected")
|
||||
}
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
})
|
||||
err := rc.Push(t.Context(), c, "zero", nil)
|
||||
err := rc.Push(t.Context(), "zero", nil)
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
err := rc.Pull(t.Context(), c, "://")
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
}
|
||||
@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "x")
|
||||
err := rc.Pull(t.Context(), "x")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
|
||||
})
|
||||
|
||||
// Confirm that the layer does not exist locally
|
||||
_, err := rc.ResolveLocal(c, "model")
|
||||
_, err := rc.ResolveLocal("model")
|
||||
checkNotExist(t, err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
checkNotExist(t, err)
|
||||
|
||||
err = rc.Pull(t.Context(), c, "model")
|
||||
err = rc.Pull(t.Context(), "model")
|
||||
check(err)
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "model")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal(c, "model")
|
||||
mg, err := rc.ResolveLocal("model")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
||||
|
||||
func TestRegistryPullCached(t *testing.T) {
|
||||
cached := blob.DigestFromBytes("exists")
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(499) // should not be called
|
||||
return
|
||||
@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := rc.Pull(ctx, c, "single")
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{6}
|
||||
@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "notfound")
|
||||
err := rc.Pull(t.Context(), "notfound")
|
||||
checkErrCode(t, err, 404, "")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "single")
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
checkErrCode(t, err, 500, "an_error")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
return
|
||||
}
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "single")
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
|
||||
// Check that we pull all layers that we can.
|
||||
|
||||
err := rc.Pull(ctx, c, "mixed")
|
||||
err := rc.Pull(ctx, "mixed")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistryPullChunking(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||
if r.URL.Host != "blob.store" {
|
||||
// The production registry redirects to the blob store.
|
||||
@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
err := rc.Pull(ctx, c, "remote")
|
||||
err := rc.Pull(ctx, "remote")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 3, 6}
|
||||
@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) {
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal(c, "single")
|
||||
_, err := rc.ResolveLocal("single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink(c, "single")
|
||||
_, err = rc.Unlink("single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal(c, "single")
|
||||
_, err = rc.ResolveLocal("single")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
ok, err := rc.Unlink(c, "manifestNotFound")
|
||||
rc, _ := newClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -6,6 +6,9 @@ import (
|
||||
|
||||
// Trace is a set of functions that are called to report progress during blob
|
||||
// downloads and uploads.
|
||||
//
|
||||
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
||||
// and [Registry.Pull].
|
||||
type Trace struct {
|
||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||
// report the progress of blob uploads and downloads.
|
||||
|
@ -63,25 +63,28 @@ func main() {
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = func() error {
|
||||
err := func() error {
|
||||
switch cmd := flag.Arg(0); cmd {
|
||||
case "pull":
|
||||
return cmdPull(ctx, rc, c)
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return cmdPull(ctx, rc)
|
||||
case "push":
|
||||
return cmdPush(ctx, rc, c)
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cmdPush(ctx, rc)
|
||||
case "import":
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cmdImport(ctx, c)
|
||||
default:
|
||||
if cmd == "" {
|
||||
@ -99,7 +102,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
||||
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
||||
model := flag.Arg(1)
|
||||
if model == "" {
|
||||
flag.Usage()
|
||||
@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
errc <- rc.Pull(ctx, c, model)
|
||||
errc <- rc.Pull(ctx, model)
|
||||
}()
|
||||
|
||||
t := time.NewTicker(time.Second)
|
||||
@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
}
|
||||
}
|
||||
|
||||
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
||||
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
||||
args := flag.Args()[1:]
|
||||
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
||||
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
||||
@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
}
|
||||
|
||||
from := cmp.Or(*flagFrom, model)
|
||||
m, err := rc.ResolveLocal(c, from)
|
||||
m, err := rc.ResolveLocal(from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
},
|
||||
})
|
||||
|
||||
return rc.Push(ctx, c, model, &ollama.PushParams{
|
||||
return rc.Push(ctx, model, &ollama.PushParams{
|
||||
From: from,
|
||||
})
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
)
|
||||
|
||||
@ -27,7 +26,6 @@ import (
|
||||
// directly to the blob disk cache.
|
||||
type Local struct {
|
||||
Client *ollama.Registry // required
|
||||
Cache *blob.DiskCache // required
|
||||
Logger *slog.Logger // required
|
||||
|
||||
// Fallback, if set, is used to handle requests that are not handled by
|
||||
@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ok, err := s.Client.Unlink(s.Cache, p.model())
|
||||
ok, err := s.Client.Unlink(p.model())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rc := &ollama.Registry{
|
||||
Cache: c,
|
||||
HTTPClient: panicOnRoundTrip,
|
||||
}
|
||||
l := &Local{
|
||||
Cache: c,
|
||||
Client: rc,
|
||||
Logger: testutil.Slogger(t),
|
||||
}
|
||||
@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
|
||||
|
||||
s := newTestServer(t)
|
||||
|
||||
_, err := s.Client.ResolveLocal(s.Cache, "smol")
|
||||
_, err := s.Client.ResolveLocal("smol")
|
||||
check(err)
|
||||
|
||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||
@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
|
||||
_, err = s.Client.ResolveLocal(s.Cache, "smol")
|
||||
_, err = s.Client.ResolveLocal("smol")
|
||||
if err == nil {
|
||||
t.Fatal("expected smol to have been deleted")
|
||||
}
|
||||
|
@ -34,7 +34,6 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/models/mllama"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
|
||||
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
corsConfig := cors.DefaultConfig()
|
||||
corsConfig.AllowWildcard = true
|
||||
corsConfig.AllowBrowserExtensions = true
|
||||
@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
|
||||
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
Cache: c,
|
||||
Client: rc,
|
||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||
Fallback: r,
|
||||
@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h, err := s.GenerateRoutes(c, rc)
|
||||
h, err := s.GenerateRoutes(rc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -23,7 +23,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
|
||||
modelsDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
|
||||
c, err := blob.Open(modelsDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open models dir: %v", err)
|
||||
}
|
||||
|
||||
rc := &ollama.Registry{
|
||||
// This is a temporary measure to allow us to move forward,
|
||||
// surfacing any code contacting ollama.com we do not intended
|
||||
@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(c, rc)
|
||||
router, err := s.GenerateRoutes(rc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate routes: %v", err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user