server/internal/client/ollama: handle extended names in client/ollama (#9454)

The extended name format is a superset of the name format that only the
client needs to know about, not the server or other dependents of the
name package, so move the split logic into the client package.

Also, take advantage of knowing about the extended name format to allow
the client to use the extended name format when unlinking to verify they
are unlinking the manifest with the content they intend.
This commit is contained in:
Blake Mizerany 2025-03-02 13:30:41 -08:00 committed by GitHub
parent af68d60a58
commit ee048b76d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 136 additions and 35 deletions

View File

@ -212,12 +212,16 @@ type Registry struct {
Mask string
}
func (r *Registry) completeName(name string) names.Name {
func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
return names.Merge(names.Parse(name), mask)
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}
// DefaultRegistry returns a new Registry configured from the environment. The
@ -306,7 +310,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name, r.Mask)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@ -400,7 +404,7 @@ func canRetry(err error) bool {
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name, r.Mask)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
@ -551,9 +555,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
n := r.completeName(name)
if !n.IsFullyQualified() {
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
n, err := r.parseName(name)
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
@ -626,10 +630,9 @@ type Layer struct {
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.Split] but the scheme is ignored.
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name, r.Mask)
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@ -655,7 +658,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name, r.Mask)
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@ -859,7 +862,7 @@ var supportedSchemes = []string{
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseName parses and validates an extended name, returning the scheme, name,
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
@ -870,8 +873,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := names.Split(s)
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
@ -894,13 +897,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges
}
}
maskName := defaultMask
if mask != "" {
maskName = names.Parse(mask)
}
n := names.Merge(names.Parse(name), maskName)
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
n, err := r.parseName(name)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
return scheme, n, d, nil
}
// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}

View File

@ -2,6 +2,7 @@ package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
@ -91,7 +92,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
link := func(name string, manifest string) {
_, n, _, err := parseName(name, r.Mask)
n, err := r.parseName(name)
if err != nil {
panic(err)
}
@ -709,25 +710,16 @@ func TestErrorUnmarshal(t *testing.T) {
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
func TestParseNameErrors(t *testing.T) {
func TestParseNameExtendedErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{"x", nil, ""},
{"x@", nil, ""},
{"", ErrNameInvalid, `invalid or missing name: ""`},
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
}
}{}
var r Registry
for _, tt := range cases {
_, _, _, err := parseName(tt.name, DefaultMask)
_, _, _, err := r.parseNameExtended(tt.name)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
@ -736,3 +728,89 @@ func TestParseNameErrors(t *testing.T) {
}
}
}
func TestParseNameExtended(t *testing.T) {
cases := []struct {
in string
scheme string
name string
digest string
err string
}{
{in: "http://m", scheme: "http", name: "m"},
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
{in: "http+insecure://m", err: "unsupported scheme"},
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
{in: "", err: "invalid or missing name"},
{in: "m", scheme: "https", name: "m"},
{in: "://", err: "invalid or missing name"},
{in: "@sha256:deadbeef", err: "invalid digest"},
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
var r Registry
scheme, n, digest, err := r.parseNameExtended(tt.in)
if err != nil {
if tt.err == "" {
t.Errorf("err = %v; want nil", err)
} else if !strings.Contains(err.Error(), tt.err) {
t.Errorf("err = %v; want %q", err, tt.err)
}
} else if tt.err != "" {
t.Errorf("err = nil; want %q", tt.err)
}
if err == nil && !n.IsFullyQualified() {
t.Errorf("name = %q; want fully qualified", n)
}
if scheme != tt.scheme {
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
}
// smoke-test name is superset of tt.name
if !strings.Contains(n.String(), tt.name) {
t.Errorf("name = %q; want %q", n, tt.name)
}
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
if digest.String() != tt.digest {
t.Errorf("digest = %q; want %q", digest, tt.digest)
}
})
}
}
func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, c := newClient(t, nil)
// confirm linked
_, err := rc.ResolveLocal(c, "single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
// unlink
_, err = rc.Unlink(c, "single")
testutil.Check(t, err)
// confirm unlinked
_, err = rc.ResolveLocal(c, "single")
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want fs.ErrNotExist", err)
}
})
t.Run("not found by name", func(t *testing.T) {
rc, c := newClient(t, nil)
ok, err := rc.Unlink(c, "manifestNotFound")
if err != nil {
t.Fatal(err)
}
if ok {
t.Error("expected not found")
}
})
}