From b48b6f85cd01d3273cc9c8c5c7a489f1121794e0 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Sun, 2 Mar 2025 15:20:58 -0800 Subject: [PATCH] server/internal/client/ollama: hold DiskCache on Registry Previously, clients of a Registry had to carry around a DiskCache to use it. This change makes the DiskCache an optional field on the Registry struct. This also changes DefaultCache to initialize one on first use. This prevents overhead of building the cache if it is never used, or per Registry request that involves use of DefaultCache. Also, slip in some minor docs on Trace. --- server/internal/client/ollama/registry.go | 69 +++++++++--- .../internal/client/ollama/registry_test.go | 102 +++++++++--------- server/internal/client/ollama/trace.go | 3 + server/internal/cmd/opp/opp.go | 39 +++---- server/internal/registry/server.go | 4 +- server/internal/registry/server_test.go | 6 +- server/routes.go | 10 +- server/routes_test.go | 8 +- 8 files changed, 136 insertions(+), 105 deletions(-) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 7ffc16db2..53ec4e795 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -27,6 +27,7 @@ import ( "slices" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -73,19 +74,22 @@ const ( DefaultMaxChunkSize = 8 << 20 ) -// DefaultCache returns a new disk cache for storing models. If the -// OLLAMA_MODELS environment variable is set, it uses that directory; -// otherwise, it uses $HOME/.ollama/models. -func DefaultCache() (*blob.DiskCache, error) { +var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) { dir := os.Getenv("OLLAMA_MODELS") if dir == "" { - home, err := os.UserHomeDir() - if err != nil { - return nil, err - } + home, _ := os.UserHomeDir() + home = cmp.Or(home, ".") dir = filepath.Join(home, ".ollama", "models") } return blob.Open(dir) +}) + +// DefaultCache returns the default cache used by the registry. It is +// configured from the OLLAMA_MODELS environment variable, or defaults to +// $HOME/.ollama/models, or, if an error occurs obtaining the home directory, +// it uses the current working directory. +func DefaultCache() (*blob.DiskCache, error) { + return defaultCache() } // Error is the standard error returned by Ollama APIs. It can represent a @@ -168,6 +172,10 @@ func CompleteName(name string) string { // Registry is a client for performing push and pull operations against an // Ollama registry. type Registry struct { + // Cache is the cache used to store models. If nil, [DefaultCache] is + // used. + Cache *blob.DiskCache + // UserAgent is the User-Agent header to send with requests to the // registry. If empty, the User-Agent is determined by HTTPClient. UserAgent string @@ -206,12 +214,18 @@ type Registry struct { // It is only used when a layer is larger than [MaxChunkingThreshold]. MaxChunkSize int64 - // Mask, if set, is the name used to convert non-fully qualified - // names to fully qualified names. If empty, the default mask - // ("registry.ollama.ai/library/_:latest") is used. + // Mask, if set, is the name used to convert non-fully qualified names + // to fully qualified names. If empty, [DefaultMask] is used. Mask string } +func (r *Registry) cache() (*blob.DiskCache, error) { + if r.Cache != nil { + return r.Cache, nil + } + return defaultCache() +} + func (r *Registry) parseName(name string) (names.Name, error) { mask := defaultMask if r.Mask != "" { @@ -241,6 +255,10 @@ func DefaultRegistry() (*Registry, error) { } var rc Registry + rc.Cache, err = defaultCache() + if err != nil { + return nil, err + } rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) if err != nil { return nil, err @@ -282,12 +300,17 @@ type PushParams struct { } // Push pushes the model with the name in the cache to the remote registry. -func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { +func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { if p == nil { p = &PushParams{} } - m, err := r.ResolveLocal(c, cmp.Or(p.From, name)) + c, err := r.cache() + if err != nil { + return err + } + + m, err := r.ResolveLocal(cmp.Or(p.From, name)) if err != nil { return err } @@ -403,7 +426,7 @@ func canRetry(err error) bool { // chunks of the specified size, and then reassembled and verified. This is // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". -func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { +func (r *Registry) Pull(ctx context.Context, name string) error { scheme, n, _, err := r.parseNameExtended(name) if err != nil { return err @@ -417,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err return fmt.Errorf("%w: no layers", ErrManifestInvalid) } + c, err := r.cache() + if err != nil { + return err + } + exists := func(l *Layer) bool { info, err := c.Get(l.Digest) return err == nil && info.Size == l.Size @@ -554,11 +582,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. -func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { +func (r *Registry) Unlink(name string) (ok bool, _ error) { n, err := r.parseName(name) if err != nil { return false, err } + c, err := r.cache() + if err != nil { + return false, err + } return c.Unlink(n.String()) } @@ -631,12 +663,17 @@ type Layer struct { } // ResolveLocal resolves a name to a Manifest in the local cache. -func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { +func (r *Registry) ResolveLocal(name string) (*Manifest, error) { _, n, d, err := r.parseNameExtended(name) if err != nil { return nil, err } + c, err := r.cache() + if err != nil { + return nil, err + } if !d.IsValid() { + // No digest, so resolve the manifest by name. d, err = c.Resolve(n.String()) if err != nil { return nil, err diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index 92b53637f..b9b4271b9 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error // To simulate a network error, pass a handler that returns a 499 status code. func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { t.Helper() + c, err := blob.Open(t.TempDir()) if err != nil { t.Fatal(err) @@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } r := &Registry{ + Cache: c, HTTPClient: &http.Client{ Transport: recordRoundTripper(h), }, @@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { } func TestPushZero(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "empty", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "empty", nil) if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want %v", err, ErrManifestInvalid) } } func TestPushSingle(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "single", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "single", nil) testutil.Check(t, err) } func TestPushMultiple(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "multiple", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "multiple", nil) testutil.Check(t, err) } func TestPushNotFound(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { t.Errorf("unexpected request: %v", r) }) - err := rc.Push(t.Context(), c, "notfound", nil) + err := rc.Push(t.Context(), "notfound", nil) if !errors.Is(err, fs.ErrNotExist) { t.Errorf("err = %v; want %v", err, fs.ErrNotExist) } } func TestPushNullLayer(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Push(t.Context(), c, "null", nil) + rc, _ := newClient(t, nil) + err := rc.Push(t.Context(), "null", nil) if err == nil || !strings.Contains(err.Error(), "invalid manifest") { t.Errorf("err = %v; want invalid manifest", err) } } func TestPushSizeMismatch(t *testing.T) { - rc, c := newClient(t, nil) + rc, _ := newClient(t, nil) ctx, _ := withTraceUnexpected(t.Context()) - got := rc.Push(ctx, c, "sizemismatch", nil) + got := rc.Push(ctx, "sizemismatch", nil) if got == nil || !strings.Contains(got.Error(), "size mismatch") { t.Errorf("err = %v; want size mismatch", got) } } func TestPushInvalid(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Push(t.Context(), c, "invalid", nil) + rc, _ := newClient(t, nil) + err := rc.Push(t.Context(), "invalid", nil) if err == nil || !strings.Contains(err.Error(), "invalid manifest") { t.Errorf("err = %v; want invalid manifest", err) } @@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) { func TestPushExistsAtRemote(t *testing.T) { var pushed bool - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/uploads/") { if !pushed { // First push. Return an uploadURL. @@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) { check := testutil.Checker(t) - err := rc.Push(ctx, c, "single", nil) + err := rc.Push(ctx, "single", nil) check(err) if !errors.Is(errors.Join(errs...), nil) { t.Errorf("errs = %v; want %v", errs, []error{ErrCached}) } - err = rc.Push(ctx, c, "single", nil) + err = rc.Push(ctx, "single", nil) check(err) } func TestPushRemoteError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { w.WriteHeader(500) io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`) return } }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) checkErrCode(t, got, 500, "blob_error") } func TestPushLocationError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Location", ":///x") w.WriteHeader(http.StatusAccepted) }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) wantContains := "invalid upload URL" if got == nil || !strings.Contains(got.Error(), wantContains) { t.Errorf("err = %v; want to contain %v", got, wantContains) @@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) { } func TestPushUploadRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if r.Host == "blob.store" { w.WriteHeader(499) // force RoundTrip error on upload return } w.Header().Set("Location", "http://blob.store/blobs/123") }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) if !errors.Is(got, errRoundTrip) { t.Errorf("got = %v; want %v", got, errRoundTrip) } @@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) { os.Remove(c.GetFile(l.Digest)) }, }) - got := rc.Push(ctx, c, "single", nil) + got := rc.Push(ctx, "single", nil) if !errors.Is(got, fs.ErrNotExist) { t.Errorf("got = %v; want fs.ErrNotExist", got) } } func TestPushCommitRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { panic("unexpected") } w.WriteHeader(499) // force RoundTrip error }) - err := rc.Push(t.Context(), c, "zero", nil) + err := rc.Push(t.Context(), "zero", nil) if !errors.Is(err, errRoundTrip) { t.Errorf("err = %v; want %v", err, errRoundTrip) } @@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) { } func TestRegistryPullInvalidName(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Pull(t.Context(), c, "://") + rc, _ := newClient(t, nil) + err := rc.Pull(t.Context(), "://") if !errors.Is(err, ErrNameInvalid) { t.Errorf("err = %v; want %v", err, ErrNameInvalid) } @@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) { } for _, resp := range cases { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, resp) }) - err := rc.Pull(t.Context(), c, "x") + err := rc.Pull(t.Context(), "x") if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want invalid manifest", err) } @@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) { }) // Confirm that the layer does not exist locally - _, err := rc.ResolveLocal(c, "model") + _, err := rc.ResolveLocal("model") checkNotExist(t, err) _, err = c.Get(d) checkNotExist(t, err) - err = rc.Pull(t.Context(), c, "model") + err = rc.Pull(t.Context(), "model") check(err) mw, err := rc.Resolve(t.Context(), "model") check(err) - mg, err := rc.ResolveLocal(c, "model") + mg, err := rc.ResolveLocal("model") check(err) if !reflect.DeepEqual(mw, mg) { t.Errorf("mw = %v; mg = %v", mw, mg) @@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) { func TestRegistryPullCached(t *testing.T) { cached := blob.DigestFromBytes("exists") - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { w.WriteHeader(499) // should not be called return @@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - err := rc.Pull(ctx, c, "single") + err := rc.Pull(ctx, "single") testutil.Check(t, err) want := []int64{6} @@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) { } func TestRegistryPullManifestNotFound(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }) - err := rc.Pull(t.Context(), c, "notfound") + err := rc.Pull(t.Context(), "notfound") checkErrCode(t, err, 404, "") } func TestRegistryPullResolveRemoteError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) }) - err := rc.Pull(t.Context(), c, "single") + err := rc.Pull(t.Context(), "single") checkErrCode(t, err, 500, "an_error") } func TestRegistryPullResolveRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/manifests/") { w.WriteHeader(499) // force RoundTrip error return } }) - err := rc.Pull(t.Context(), c, "single") + err := rc.Pull(t.Context(), "single") if !errors.Is(err, errRoundTrip) { t.Errorf("err = %v; want %v", err, errRoundTrip) } @@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { // Check that we pull all layers that we can. - err := rc.Pull(ctx, c, "mixed") + err := rc.Pull(ctx, "mixed") if err != nil { t.Fatal(err) } @@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { } func TestRegistryPullChunking(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) if r.URL.Host != "blob.store" { // The production registry redirects to the blob store. @@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) { }, }) - err := rc.Pull(ctx, c, "remote") + err := rc.Pull(ctx, "remote") testutil.Check(t, err) want := []int64{0, 3, 6} @@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) { func TestUnlink(t *testing.T) { t.Run("found by name", func(t *testing.T) { - rc, c := newClient(t, nil) + rc, _ := newClient(t, nil) // confirm linked - _, err := rc.ResolveLocal(c, "single") + _, err := rc.ResolveLocal("single") if err != nil { t.Errorf("unexpected error: %v", err) } // unlink - _, err = rc.Unlink(c, "single") + _, err = rc.Unlink("single") testutil.Check(t, err) // confirm unlinked - _, err = rc.ResolveLocal(c, "single") + _, err = rc.ResolveLocal("single") if !errors.Is(err, fs.ErrNotExist) { t.Errorf("err = %v; want fs.ErrNotExist", err) } }) t.Run("not found by name", func(t *testing.T) { - rc, c := newClient(t, nil) - ok, err := rc.Unlink(c, "manifestNotFound") + rc, _ := newClient(t, nil) + ok, err := rc.Unlink("manifestNotFound") if err != nil { t.Fatal(err) } diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go index 8e53040ad..e300870bb 100644 --- a/server/internal/client/ollama/trace.go +++ b/server/internal/client/ollama/trace.go @@ -6,6 +6,9 @@ import ( // Trace is a set of functions that are called to report progress during blob // downloads and uploads. +// +// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push] +// and [Registry.Pull]. type Trace struct { // Update is called during [Registry.Push] and [Registry.Pull] to // report the progress of blob uploads and downloads. diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go index c21e71d59..6976927c7 100644 --- a/server/internal/cmd/opp/opp.go +++ b/server/internal/cmd/opp/opp.go @@ -63,25 +63,28 @@ func main() { } flag.Parse() - c, err := ollama.DefaultCache() - if err != nil { - log.Fatal(err) - } - - rc, err := ollama.DefaultRegistry() - if err != nil { - log.Fatal(err) - } - ctx := context.Background() - err = func() error { + err := func() error { switch cmd := flag.Arg(0); cmd { case "pull": - return cmdPull(ctx, rc, c) + rc, err := ollama.DefaultRegistry() + if err != nil { + log.Fatal(err) + } + + return cmdPull(ctx, rc) case "push": - return cmdPush(ctx, rc, c) + rc, err := ollama.DefaultRegistry() + if err != nil { + log.Fatal(err) + } + return cmdPush(ctx, rc) case "import": + c, err := ollama.DefaultCache() + if err != nil { + log.Fatal(err) + } return cmdImport(ctx, c) default: if cmd == "" { @@ -99,7 +102,7 @@ func main() { } } -func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { +func cmdPull(ctx context.Context, rc *ollama.Registry) error { model := flag.Arg(1) if model == "" { flag.Usage() @@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error errc := make(chan error) go func() { - errc <- rc.Pull(ctx, c, model) + errc <- rc.Pull(ctx, model) }() t := time.NewTicker(time.Second) @@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error } } -func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { +func cmdPush(ctx context.Context, rc *ollama.Registry) error { args := flag.Args()[1:] flag := flag.NewFlagSet("push", flag.ExitOnError) flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") @@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error } from := cmp.Or(*flagFrom, model) - m, err := rc.ResolveLocal(c, from) + m, err := rc.ResolveLocal(from) if err != nil { return err } @@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error }, }) - return rc.Push(ctx, c, model, &ollama.PushParams{ + return rc.Push(ctx, model, &ollama.PushParams{ From: from, }) } diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 6ea590a70..81085357d 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -11,7 +11,6 @@ import ( "log/slog" "net/http" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" ) @@ -27,7 +26,6 @@ import ( // directly to the blob disk cache. type Local struct { Client *ollama.Registry // required - Cache *blob.DiskCache // required Logger *slog.Logger // required // Fallback, if set, is used to handle requests that are not handled by @@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { if err != nil { return err } - ok, err := s.Client.Unlink(s.Cache, p.model()) + ok, err := s.Client.Unlink(p.model()) if err != nil { return err } diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index 7ba13d501..e44d88c0f 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local { t.Fatal(err) } rc := &ollama.Registry{ + Cache: c, HTTPClient: panicOnRoundTrip, } l := &Local{ - Cache: c, Client: rc, Logger: testutil.Slogger(t), } @@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) { s := newTestServer(t) - _, err := s.Client.ResolveLocal(s.Cache, "smol") + _, err := s.Client.ResolveLocal("smol") check(err) got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`) @@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) { t.Fatalf("Code = %d; want 200", got.Code) } - _, err = s.Client.ResolveLocal(s.Cache, "smol") + _, err = s.Client.ResolveLocal("smol") if err == nil { t.Fatal("expected smol to have been deleted") } diff --git a/server/routes.go b/server/routes.go index ff42000f8..519f04738 100644 --- a/server/routes.go +++ b/server/routes.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } } -func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) { +func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { corsConfig := cors.DefaultConfig() corsConfig.AllowWildcard = true corsConfig.AllowBrowserExtensions = true @@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha // wrap old with new rs := ®istry.Local{ - Cache: c, Client: rc, Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() Fallback: r, @@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error { s := &Server{addr: ln.Addr()} - c, err := ollama.DefaultCache() - if err != nil { - return err - } rc, err := ollama.DefaultRegistry() if err != nil { return err } - h, err := s.GenerateRoutes(c, rc) + h, err := s.GenerateRoutes(rc) if err != nil { return err } diff --git a/server/routes_test.go b/server/routes_test.go index 0dd782f4f..e13c4b599 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -23,7 +23,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) { modelsDir := t.TempDir() t.Setenv("OLLAMA_MODELS", modelsDir) - c, err := blob.Open(modelsDir) - if err != nil { - t.Fatalf("failed to open models dir: %v", err) - } - rc := &ollama.Registry{ // This is a temporary measure to allow us to move forward, // surfacing any code contacting ollama.com we do not intended @@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) { } s := &Server{} - router, err := s.GenerateRoutes(c, rc) + router, err := s.GenerateRoutes(rc) if err != nil { t.Fatalf("failed to generate routes: %v", err) }