diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 82a8bbca4..7ffc16db2 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -212,12 +212,16 @@ type Registry struct { Mask string } -func (r *Registry) completeName(name string) names.Name { +func (r *Registry) parseName(name string) (names.Name, error) { mask := defaultMask if r.Mask != "" { mask = names.Parse(r.Mask) } - return names.Merge(names.Parse(name), mask) + n := names.Merge(names.Parse(name), mask) + if !n.IsFullyQualified() { + return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name) + } + return n, nil } // DefaultRegistry returns a new Registry configured from the environment. The @@ -306,7 +310,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * t := traceFromContext(ctx) - scheme, n, _, err := parseName(name, r.Mask) + scheme, n, _, err := r.parseNameExtended(name) if err != nil { // This should never happen since ResolveLocal should have // already validated the name. @@ -400,7 +404,7 @@ func canRetry(err error) bool { // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { - scheme, n, _, err := parseName(name, r.Mask) + scheme, n, _, err := r.parseNameExtended(name) if err != nil { return err } @@ -551,9 +555,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { - n := r.completeName(name) - if !n.IsFullyQualified() { - return false, fmt.Errorf("%w: %q", ErrNameInvalid, name) + n, err := r.parseName(name) + if err != nil { + return false, err } return c.Unlink(n.String()) } @@ -626,10 +630,9 @@ type Layer struct { Size int64 `json:"size"` } -// ResolveLocal resolves a name to a Manifest in the local cache. The name is -// parsed using [names.Split] but the scheme is ignored. +// ResolveLocal resolves a name to a Manifest in the local cache. func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { - _, n, d, err := parseName(name, r.Mask) + _, n, d, err := r.parseNameExtended(name) if err != nil { return nil, err } @@ -655,7 +658,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro // Resolve resolves a name to a Manifest in the remote registry. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { - scheme, n, d, err := parseName(name, r.Mask) + scheme, n, d, err := r.parseNameExtended(name) if err != nil { return nil, err } @@ -859,7 +862,7 @@ var supportedSchemes = []string{ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", ")) -// parseName parses and validates an extended name, returning the scheme, name, +// parseNameExtended parses and validates an extended name, returning the scheme, name, // and digest. // // If the scheme is empty, scheme will be "https". If an unsupported scheme is @@ -870,8 +873,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo // // If the name is not, once merged with the mask, fully qualified, // [ErrNameInvalid] wrapped with a display friendly message is returned. -func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) { - scheme, name, digest := names.Split(s) +func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) { + scheme, name, digest := splitExtended(s) scheme = cmp.Or(scheme, "https") if !slices.Contains(supportedSchemes, scheme) { err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) @@ -894,13 +897,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges } } - maskName := defaultMask - if mask != "" { - maskName = names.Parse(mask) - } - n := names.Merge(names.Parse(name), maskName) - if !n.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) + n, err := r.parseName(name) + if err != nil { + return "", names.Name{}, blob.Digest{}, err } return scheme, n, d, nil } + +// splitExtended splits an extended name string into its scheme, name, and digest +// parts. +// +// Examples: +// +// http://ollama.com/bmizerany/smol:latest@digest +// https://ollama.com/bmizerany/smol:latest +// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. +// model@digest +// @digest +func splitExtended(s string) (scheme, name, digest string) { + i := strings.Index(s, "://") + if i >= 0 { + scheme = s[:i] + s = s[i+3:] + } + i = strings.LastIndex(s, "@") + if i >= 0 { + digest = s[i+1:] + s = s[:i] + } + return scheme, s, digest +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index 20a1f1593..92b53637f 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -2,6 +2,7 @@ package ollama import ( "bytes" + "cmp" "context" "encoding/json" "errors" @@ -91,7 +92,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } link := func(name string, manifest string) { - _, n, _, err := parseName(name, r.Mask) + n, err := r.parseName(name) if err != nil { panic(err) } @@ -709,25 +710,16 @@ func TestErrorUnmarshal(t *testing.T) { // // It is only for testing error messages, not that all invalids and valids are // covered. Those are in other tests for names.Name and blob.Digest. -func TestParseNameErrors(t *testing.T) { +func TestParseNameExtendedErrors(t *testing.T) { cases := []struct { name string err error want string - }{ - {"x", nil, ""}, - {"x@", nil, ""}, - - {"", ErrNameInvalid, `invalid or missing name: ""`}, - {"://", ErrNameInvalid, `invalid or missing name: "://"`}, - {"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`}, - - {"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, - {"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, - } + }{} + var r Registry for _, tt := range cases { - _, _, _, err := parseName(tt.name, DefaultMask) + _, _, _, err := r.parseNameExtended(tt.name) if !errors.Is(err, tt.err) { t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err) } @@ -736,3 +728,89 @@ func TestParseNameErrors(t *testing.T) { } } } + +func TestParseNameExtended(t *testing.T) { + cases := []struct { + in string + scheme string + name string + digest string + err string + }{ + {in: "http://m", scheme: "http", name: "m"}, + {in: "https+insecure://m", scheme: "https+insecure", name: "m"}, + {in: "http+insecure://m", err: "unsupported scheme"}, + + {in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"}, + + {in: "", err: "invalid or missing name"}, + {in: "m", scheme: "https", name: "m"}, + {in: "://", err: "invalid or missing name"}, + {in: "@sha256:deadbeef", err: "invalid digest"}, + {in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + var r Registry + scheme, n, digest, err := r.parseNameExtended(tt.in) + if err != nil { + if tt.err == "" { + t.Errorf("err = %v; want nil", err) + } else if !strings.Contains(err.Error(), tt.err) { + t.Errorf("err = %v; want %q", err, tt.err) + } + } else if tt.err != "" { + t.Errorf("err = nil; want %q", tt.err) + } + if err == nil && !n.IsFullyQualified() { + t.Errorf("name = %q; want fully qualified", n) + } + + if scheme != tt.scheme { + t.Errorf("scheme = %q; want %q", scheme, tt.scheme) + } + + // smoke-test name is superset of tt.name + if !strings.Contains(n.String(), tt.name) { + t.Errorf("name = %q; want %q", n, tt.name) + } + + tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String()) + if digest.String() != tt.digest { + t.Errorf("digest = %q; want %q", digest, tt.digest) + } + }) + } +} + +func TestUnlink(t *testing.T) { + t.Run("found by name", func(t *testing.T) { + rc, c := newClient(t, nil) + + // confirm linked + _, err := rc.ResolveLocal(c, "single") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // unlink + _, err = rc.Unlink(c, "single") + testutil.Check(t, err) + + // confirm unlinked + _, err = rc.ResolveLocal(c, "single") + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("err = %v; want fs.ErrNotExist", err) + } + }) + t.Run("not found by name", func(t *testing.T) { + rc, c := newClient(t, nil) + ok, err := rc.Unlink(c, "manifestNotFound") + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("expected not found") + } + }) +}