mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 11:27:15 +01:00
bugfix: don't include both consolidated.safetensors and model-*.safetensors (#13010)
This commit is contained in:
@@ -260,10 +260,13 @@ func filesForModel(path string) ([]string, error) {
|
||||
|
||||
var files []string
|
||||
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
|
||||
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
|
||||
// safetensors files might be unresolved git lfs references; skip if they are
|
||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||
files = append(files, st...)
|
||||
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
|
||||
// covers consolidated.safetensors
|
||||
files = append(files, st...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf16"
|
||||
@@ -855,3 +856,273 @@ func TestCreateRequestFiles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesForModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(string) error
|
||||
wantFiles []string
|
||||
wantErr bool
|
||||
expectErrType error
|
||||
}{
|
||||
{
|
||||
name: "safetensors model files",
|
||||
setup: func(dir string) error {
|
||||
files := []string{
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safetensors with consolidated files - prefers model files",
|
||||
setup: func(dir string) error {
|
||||
files := []string{
|
||||
"model-00001-of-00001.safetensors",
|
||||
"consolidated.safetensors",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"model-00001-of-00001.safetensors", // consolidated files should be excluded
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safetensors without model-.safetensors files - uses consolidated",
|
||||
setup: func(dir string) error {
|
||||
files := []string{
|
||||
"consolidated.safetensors",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"consolidated.safetensors",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pytorch model files",
|
||||
setup: func(dir string) error {
|
||||
// Create a file that will be detected as application/zip
|
||||
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
|
||||
files := []string{
|
||||
"pytorch_model-00001-of-00002.bin",
|
||||
"pytorch_model-00002-of-00002.bin",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
content := zipHeader
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"pytorch_model-00001-of-00002.bin",
|
||||
"pytorch_model-00002-of-00002.bin",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "consolidated pth files",
|
||||
setup: func(dir string) error {
|
||||
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
|
||||
files := []string{
|
||||
"consolidated.00.pth",
|
||||
"consolidated.01.pth",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
content := zipHeader
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"consolidated.00.pth",
|
||||
"consolidated.01.pth",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gguf files",
|
||||
setup: func(dir string) error {
|
||||
// Create binary content that will be detected as application/octet-stream
|
||||
binaryContent := make([]byte, 512)
|
||||
for i := range binaryContent {
|
||||
binaryContent[i] = byte(i % 256)
|
||||
}
|
||||
files := []string{
|
||||
"model.gguf",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
content := binaryContent
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"model.gguf",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bin files as gguf",
|
||||
setup: func(dir string) error {
|
||||
binaryContent := make([]byte, 512)
|
||||
for i := range binaryContent {
|
||||
binaryContent[i] = byte(i % 256)
|
||||
}
|
||||
files := []string{
|
||||
"model.bin",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
content := binaryContent
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"model.bin",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no model files found",
|
||||
setup: func(dir string) error {
|
||||
// Only create non-model files
|
||||
files := []string{"README.md", "config.json"}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: true,
|
||||
expectErrType: ErrModelNotFound,
|
||||
},
|
||||
{
|
||||
name: "invalid content type for pytorch model",
|
||||
setup: func(dir string) error {
|
||||
// Create pytorch model file with wrong content type (text instead of zip)
|
||||
files := []string{
|
||||
"pytorch_model.bin",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
content := []byte("plain text content")
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testDir := filepath.Join(tmpDir, tt.name)
|
||||
if err := os.MkdirAll(testDir, 0o755); err != nil {
|
||||
t.Fatalf("Failed to create test directory: %v", err)
|
||||
}
|
||||
|
||||
if err := tt.setup(testDir); err != nil {
|
||||
t.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
|
||||
files, err := filesForModel(testDir)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
if tt.expectErrType != nil && err != tt.expectErrType {
|
||||
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var relativeFiles []string
|
||||
for _, file := range files {
|
||||
rel, err := filepath.Rel(testDir, file)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get relative path: %v", err)
|
||||
}
|
||||
relativeFiles = append(relativeFiles, rel)
|
||||
}
|
||||
|
||||
if len(relativeFiles) != len(tt.wantFiles) {
|
||||
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
|
||||
}
|
||||
|
||||
fileSet := make(map[string]bool)
|
||||
for _, file := range relativeFiles {
|
||||
fileSet[file] = true
|
||||
}
|
||||
|
||||
for _, wantFile := range tt.wantFiles {
|
||||
if !fileSet[wantFile] {
|
||||
t.Errorf("Missing expected file: %s", wantFile)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user