From eed11ded30f1fa386632ed8922a4349fe5140096 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Fri, 28 Feb 2025 13:08:10 -0800 Subject: [PATCH] server/.../safetensors: fix offsets and include all model parts (#9427) Also, require the -as flag to be set when importing a model. This prevents the confusing error message "invalid name". Also, allow short names to be used when importing a model and auto-complete the name with the default mask. --- server/internal/client/ollama/registry.go | 17 +++++++++++++---- .../cmd/opp/internal/safetensors/safetensors.go | 8 ++++++-- server/internal/cmd/opp/opp.go | 8 +++++++- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index d4d58ed61..e4c36d7d8 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -147,14 +147,23 @@ func (e *Error) UnmarshalJSON(b []byte) error { return nil } -var defaultName = func() names.Name { - n := names.Parse("registry.ollama.ai/library/_:latest") +const DefaultMask = "registry.ollama.ai/library/_:latest" + +var defaultMask = func() names.Name { + n := names.Parse(DefaultMask) if !n.IsFullyQualified() { - panic("default name is not fully qualified") + panic("default mask is not fully qualified") } return n }() +// CompleteName returns a fully qualified name by merging the given name with +// the default mask. If the name is already fully qualified, it is returned +// unchanged. +func CompleteName(name string) string { + return names.Merge(names.Parse(name), defaultMask).String() +} + // Registry is a client for performing push and pull operations against an // Ollama registry. type Registry struct { @@ -249,7 +258,7 @@ type PushParams struct { // // The scheme is returned as provided by [names.ParseExtended]. func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) { - maskName := defaultName + maskName := defaultMask if mask != "" { maskName = names.Parse(mask) if !maskName.IsFullyQualified() { diff --git a/server/internal/cmd/opp/internal/safetensors/safetensors.go b/server/internal/cmd/opp/internal/safetensors/safetensors.go index 7f3e99798..7a45b91df 100644 --- a/server/internal/cmd/opp/internal/safetensors/safetensors.go +++ b/server/internal/cmd/opp/internal/safetensors/safetensors.go @@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) { return nil, err } + endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself + // TODO(bmizerany): do something with metadata? This could be another // header read if needed. We also need to figure out if the metadata is // present in only one .safetensors file or if each file may have their @@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) { tt := make([]*Tensor, 0, len(raws)) for name, raw := range raws { - if !strings.HasPrefix(name, "model.layer") { + if name == "__metadata__" { + // TODO(bmizerany): do something with metadata? continue } var v struct { @@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) { // TODO(bmizerany): after collecting, validate all offests make // tensors contiguous? - begin, end := v.Offsets[0], v.Offsets[1] + begin := endOfHeader + v.Offsets[0] + end := endOfHeader + v.Offsets[1] if err := checkBeginEnd(finfo.Size(), begin, end); err != nil { return nil, err } diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go index cc10a72ff..c21e71d59 100644 --- a/server/internal/cmd/opp/opp.go +++ b/server/internal/cmd/opp/opp.go @@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error { flag.PrintDefaults() } flag.Parse(args) + if *flagAs == "" { + return fmt.Errorf("missing -as flag") + } + as := ollama.CompleteName(*flagAs) dir := cmp.Or(flag.Arg(0), ".") fmt.Fprintf(os.Stderr, "Reading %s\n", dir) @@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error { if err != nil { return err } - return c.Link(*flagAs, d) + return c.Link(as, d) }() }() @@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error { writeProgress() case err := <-done: writeProgress() + fmt.Println() + fmt.Println("Successfully imported", as) return err } }