server/.../safetensors: fix offsets and include all model parts (#9427)

Also, require the -as flag to be set when importing a model. This
prevents the confusing error message "invalid name".

Also, allow short names to be used when importing a model and
auto-complete the name with the default mask.
This commit is contained in:
Blake Mizerany 2025-02-28 13:08:10 -08:00 committed by GitHub
parent b42aba40ed
commit eed11ded30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 7 deletions

View File

@ -147,14 +147,23 @@ func (e *Error) UnmarshalJSON(b []byte) error {
return nil
}
var defaultName = func() names.Name {
n := names.Parse("registry.ollama.ai/library/_:latest")
const DefaultMask = "registry.ollama.ai/library/_:latest"
var defaultMask = func() names.Name {
n := names.Parse(DefaultMask)
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
panic("default mask is not fully qualified")
}
return n
}()
// CompleteName returns a fully qualified name by merging the given name with
// the default mask. If the name is already fully qualified, it is returned
// unchanged.
func CompleteName(name string) string {
return names.Merge(names.Parse(name), defaultMask).String()
}
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
@ -249,7 +258,7 @@ type PushParams struct {
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
maskName := defaultName
maskName := defaultMask
if mask != "" {
maskName = names.Parse(mask)
if !maskName.IsFullyQualified() {

View File

@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
return nil, err
}
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
// TODO(bmizerany): do something with metadata? This could be another
// header read if needed. We also need to figure out if the metadata is
// present in only one .safetensors file or if each file may have their
@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
tt := make([]*Tensor, 0, len(raws))
for name, raw := range raws {
if !strings.HasPrefix(name, "model.layer") {
if name == "__metadata__" {
// TODO(bmizerany): do something with metadata?
continue
}
var v struct {
@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
// TODO(bmizerany): after collecting, validate all offests make
// tensors contiguous?
begin, end := v.Offsets[0], v.Offsets[1]
begin := endOfHeader + v.Offsets[0]
end := endOfHeader + v.Offsets[1]
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
return nil, err
}

View File

@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
flag.PrintDefaults()
}
flag.Parse(args)
if *flagAs == "" {
return fmt.Errorf("missing -as flag")
}
as := ollama.CompleteName(*flagAs)
dir := cmp.Or(flag.Arg(0), ".")
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
if err != nil {
return err
}
return c.Link(*flagAs, d)
return c.Link(as, d)
}()
}()
@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
writeProgress()
case err := <-done:
writeProgress()
fmt.Println()
fmt.Println("Successfully imported", as)
return err
}
}