mirror of
https://github.com/ollama/ollama.git
synced 2025-04-04 09:58:31 +02:00
server/.../safetensors: fix offsets and include all model parts (#9427)
Also, require the -as flag to be set when importing a model; this prevents the confusing "invalid name" error message. In addition, allow short names to be used when importing a model, auto-completing the name with the default mask.
This commit is contained in:
parent
b42aba40ed
commit
eed11ded30
@ -147,14 +147,23 @@ func (e *Error) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultName = func() names.Name {
|
||||
n := names.Parse("registry.ollama.ai/library/_:latest")
|
||||
const DefaultMask = "registry.ollama.ai/library/_:latest"
|
||||
|
||||
var defaultMask = func() names.Name {
|
||||
n := names.Parse(DefaultMask)
|
||||
if !n.IsFullyQualified() {
|
||||
panic("default name is not fully qualified")
|
||||
panic("default mask is not fully qualified")
|
||||
}
|
||||
return n
|
||||
}()
|
||||
|
||||
// CompleteName returns a fully qualified name by merging the given name with
|
||||
// the default mask. If the name is already fully qualified, it is returned
|
||||
// unchanged.
|
||||
func CompleteName(name string) string {
|
||||
return names.Merge(names.Parse(name), defaultMask).String()
|
||||
}
|
||||
|
||||
// Registry is a client for performing push and pull operations against an
|
||||
// Ollama registry.
|
||||
type Registry struct {
|
||||
@ -249,7 +258,7 @@ type PushParams struct {
|
||||
//
|
||||
// The scheme is returned as provided by [names.ParseExtended].
|
||||
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
|
||||
maskName := defaultName
|
||||
maskName := defaultMask
|
||||
if mask != "" {
|
||||
maskName = names.Parse(mask)
|
||||
if !maskName.IsFullyQualified() {
|
||||
|
@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
|
||||
|
||||
// TODO(bmizerany): do something with metadata? This could be another
|
||||
// header read if needed. We also need to figure out if the metadata is
|
||||
// present in only one .safetensors file or if each file may have their
|
||||
@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
|
||||
tt := make([]*Tensor, 0, len(raws))
|
||||
for name, raw := range raws {
|
||||
if !strings.HasPrefix(name, "model.layer") {
|
||||
if name == "__metadata__" {
|
||||
// TODO(bmizerany): do something with metadata?
|
||||
continue
|
||||
}
|
||||
var v struct {
|
||||
@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
|
||||
// TODO(bmizerany): after collecting, validate all offsets make
|
||||
// tensors contiguous?
|
||||
begin, end := v.Offsets[0], v.Offsets[1]
|
||||
begin := endOfHeader + v.Offsets[0]
|
||||
end := endOfHeader + v.Offsets[1]
|
||||
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse(args)
|
||||
if *flagAs == "" {
|
||||
return fmt.Errorf("missing -as flag")
|
||||
}
|
||||
as := ollama.CompleteName(*flagAs)
|
||||
|
||||
dir := cmp.Or(flag.Arg(0), ".")
|
||||
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
|
||||
@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Link(*flagAs, d)
|
||||
return c.Link(as, d)
|
||||
}()
|
||||
}()
|
||||
|
||||
@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
||||
writeProgress()
|
||||
case err := <-done:
|
||||
writeProgress()
|
||||
fmt.Println()
|
||||
fmt.Println("Successfully imported", as)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user