mirror of
https://github.com/ollama/ollama.git
synced 2025-04-04 18:12:10 +02:00
server/internal/client/ollama: handle extended names in client/ollama (#9454)
The extended name format is a superset of the name format that only the client needs to know about, not the server or other dependents of the name package, so move the split logic into the client package. Also, take advantage of knowing about the extended name format to allow the client to use the extended name format when unlinking to verify they are unlinking the manifest with the content they intend.
This commit is contained in:
parent
af68d60a58
commit
ee048b76d4
@ -212,12 +212,16 @@ type Registry struct {
|
||||
Mask string
|
||||
}
|
||||
|
||||
func (r *Registry) completeName(name string) names.Name {
|
||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
||||
mask := defaultMask
|
||||
if r.Mask != "" {
|
||||
mask = names.Parse(r.Mask)
|
||||
}
|
||||
return names.Merge(names.Parse(name), mask)
|
||||
n := names.Merge(names.Parse(name), mask)
|
||||
if !n.IsFullyQualified() {
|
||||
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||
@ -306,7 +310,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
|
||||
|
||||
t := traceFromContext(ctx)
|
||||
|
||||
scheme, n, _, err := parseName(name, r.Mask)
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
// This should never happen since ResolveLocal should have
|
||||
// already validated the name.
|
||||
@ -400,7 +404,7 @@ func canRetry(err error) bool {
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||
scheme, n, _, err := parseName(name, r.Mask)
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -551,9 +555,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||
// before attempting to unlink the model.
|
||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||
n := r.completeName(name)
|
||||
if !n.IsFullyQualified() {
|
||||
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return c.Unlink(n.String())
|
||||
}
|
||||
@ -626,10 +630,9 @@ type Layer struct {
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
|
||||
// parsed using [names.Split] but the scheme is ignored.
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
_, n, d, err := parseName(name, r.Mask)
|
||||
_, n, d, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -655,7 +658,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
|
||||
|
||||
// Resolve resolves a name to a Manifest in the remote registry.
|
||||
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
||||
scheme, n, d, err := parseName(name, r.Mask)
|
||||
scheme, n, d, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -859,7 +862,7 @@ var supportedSchemes = []string{
|
||||
|
||||
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
|
||||
|
||||
// parseName parses and validates an extended name, returning the scheme, name,
|
||||
// parseNameExtended parses and validates an extended name, returning the scheme, name,
|
||||
// and digest.
|
||||
//
|
||||
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
||||
@ -870,8 +873,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
|
||||
//
|
||||
// If the name is not, once merged with the mask, fully qualified,
|
||||
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
||||
scheme, name, digest := names.Split(s)
|
||||
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
||||
scheme, name, digest := splitExtended(s)
|
||||
scheme = cmp.Or(scheme, "https")
|
||||
if !slices.Contains(supportedSchemes, scheme) {
|
||||
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
||||
@ -894,13 +897,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges
|
||||
}
|
||||
}
|
||||
|
||||
maskName := defaultMask
|
||||
if mask != "" {
|
||||
maskName = names.Parse(mask)
|
||||
}
|
||||
n := names.Merge(names.Parse(name), maskName)
|
||||
if !n.IsFullyQualified() {
|
||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
return scheme, n, d, nil
|
||||
}
|
||||
|
||||
// splitExtended splits an extended name string into its scheme, name, and digest
|
||||
// parts.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
// model@digest
|
||||
// @digest
|
||||
func splitExtended(s string) (scheme, name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
s = s[i+3:]
|
||||
}
|
||||
i = strings.LastIndex(s, "@")
|
||||
if i >= 0 {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@ -91,7 +92,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
}
|
||||
|
||||
link := func(name string, manifest string) {
|
||||
_, n, _, err := parseName(name, r.Mask)
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -709,25 +710,16 @@ func TestErrorUnmarshal(t *testing.T) {
|
||||
//
|
||||
// It is only for testing error messages, not that all invalids and valids are
|
||||
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||
func TestParseNameErrors(t *testing.T) {
|
||||
func TestParseNameExtendedErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{"x", nil, ""},
|
||||
{"x@", nil, ""},
|
||||
|
||||
{"", ErrNameInvalid, `invalid or missing name: ""`},
|
||||
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
|
||||
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
|
||||
|
||||
{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||
}
|
||||
}{}
|
||||
|
||||
var r Registry
|
||||
for _, tt := range cases {
|
||||
_, _, _, err := parseName(tt.name, DefaultMask)
|
||||
_, _, _, err := r.parseNameExtended(tt.name)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||
}
|
||||
@ -736,3 +728,89 @@ func TestParseNameErrors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
scheme string
|
||||
name string
|
||||
digest string
|
||||
err string
|
||||
}{
|
||||
{in: "http://m", scheme: "http", name: "m"},
|
||||
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
||||
{in: "http+insecure://m", err: "unsupported scheme"},
|
||||
|
||||
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
||||
|
||||
{in: "", err: "invalid or missing name"},
|
||||
{in: "m", scheme: "https", name: "m"},
|
||||
{in: "://", err: "invalid or missing name"},
|
||||
{in: "@sha256:deadbeef", err: "invalid digest"},
|
||||
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
var r Registry
|
||||
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
||||
if err != nil {
|
||||
if tt.err == "" {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
} else if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("err = %v; want %q", err, tt.err)
|
||||
}
|
||||
} else if tt.err != "" {
|
||||
t.Errorf("err = nil; want %q", tt.err)
|
||||
}
|
||||
if err == nil && !n.IsFullyQualified() {
|
||||
t.Errorf("name = %q; want fully qualified", n)
|
||||
}
|
||||
|
||||
if scheme != tt.scheme {
|
||||
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
||||
}
|
||||
|
||||
// smoke-test name is superset of tt.name
|
||||
if !strings.Contains(n.String(), tt.name) {
|
||||
t.Errorf("name = %q; want %q", n, tt.name)
|
||||
}
|
||||
|
||||
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
||||
if digest.String() != tt.digest {
|
||||
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal(c, "single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink(c, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal(c, "single")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
ok, err := rc.Unlink(c, "manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("expected not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user