bugfix: don't include both consolidated.safetensors and model-*.safetensors (#13010)

This commit is contained in:
Patrick Devine
2025-11-07 22:41:57 -08:00
committed by GitHub
parent 755ac3b069
commit 91ec3ddbeb
2 changed files with 275 additions and 1 deletions

View File

@@ -260,10 +260,13 @@ func filesForModel(path string) ([]string, error) {
var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
// covers consolidated.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin

View File

@@ -9,6 +9,7 @@ import (
"io"
"maps"
"os"
"path/filepath"
"strings"
"testing"
"unicode/utf16"
@@ -855,3 +856,273 @@ func TestCreateRequestFiles(t *testing.T) {
}
}
}
// TestFilesForModel checks which files filesForModel returns for a model
// directory, one subtest per weight-file layout, plus the error cases where
// no usable weights are present.
func TestFilesForModel(t *testing.T) {
	// PK zip header; intended to make a fixture be detected as application/zip.
	zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}

	// 512 bytes cycling through every byte value; intended to make a fixture
	// be detected as application/octet-stream.
	binaryContent := make([]byte, 512)
	for i := range binaryContent {
		binaryContent[i] = byte(i % 256)
	}

	// uniform gives every fixture file the same content.
	uniform := func(b []byte) func(string) []byte {
		return func(string) []byte { return b }
	}
	// weights gives config.json a small JSON body and every other file b.
	weights := func(b []byte) func(string) []byte {
		return func(name string) []byte {
			if name == "config.json" {
				return []byte(`{"config": true}`)
			}
			return b
		}
	}

	tests := []struct {
		name          string
		files         []string                  // files created in the model directory
		content       func(name string) []byte  // content written for each created file
		wantFiles     []string                  // expected result, relative to the model directory
		wantErr       bool                      // filesForModel must return an error
		expectErrType error                     // if non-nil, the exact error value expected
	}{
		{
			name: "safetensors model files",
			files: []string{
				"model-00001-of-00002.safetensors",
				"model-00002-of-00002.safetensors",
				"config.json",
				"tokenizer.json",
			},
			content: uniform([]byte("test content")),
			wantFiles: []string{
				"model-00001-of-00002.safetensors",
				"model-00002-of-00002.safetensors",
				"config.json",
				"tokenizer.json",
			},
		},
		{
			name: "safetensors with consolidated files - prefers model files",
			files: []string{
				"model-00001-of-00001.safetensors",
				"consolidated.safetensors",
				"config.json",
			},
			content: uniform([]byte("test content")),
			wantFiles: []string{
				"model-00001-of-00001.safetensors", // consolidated files should be excluded
				"config.json",
			},
		},
		{
			name: "safetensors without model-.safetensors files - uses consolidated",
			files: []string{
				"consolidated.safetensors",
				"config.json",
			},
			content: uniform([]byte("test content")),
			wantFiles: []string{
				"consolidated.safetensors",
				"config.json",
			},
		},
		{
			name: "pytorch model files",
			files: []string{
				"pytorch_model-00001-of-00002.bin",
				"pytorch_model-00002-of-00002.bin",
				"config.json",
			},
			content: weights(zipHeader),
			wantFiles: []string{
				"pytorch_model-00001-of-00002.bin",
				"pytorch_model-00002-of-00002.bin",
				"config.json",
			},
		},
		{
			name: "consolidated pth files",
			files: []string{
				"consolidated.00.pth",
				"consolidated.01.pth",
				"config.json",
			},
			content: weights(zipHeader),
			wantFiles: []string{
				"consolidated.00.pth",
				"consolidated.01.pth",
				"config.json",
			},
		},
		{
			name: "gguf files",
			files: []string{
				"model.gguf",
				"config.json",
			},
			content: weights(binaryContent),
			wantFiles: []string{
				"model.gguf",
				"config.json",
			},
		},
		{
			name: "bin files as gguf",
			files: []string{
				"model.bin",
				"config.json",
			},
			content: weights(binaryContent),
			wantFiles: []string{
				"model.bin",
				"config.json",
			},
		},
		{
			name: "no model files found",
			// Only non-model files are created.
			files:         []string{"README.md", "config.json"},
			content:       uniform([]byte("content")),
			wantErr:       true,
			expectErrType: ErrModelNotFound,
		},
		{
			name: "invalid content type for pytorch model",
			// The pytorch model file holds plain text instead of zip data.
			files: []string{
				"pytorch_model.bin",
				"config.json",
			},
			content: uniform([]byte("plain text content")),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// A fresh directory per subtest keeps the fixtures isolated and
			// avoids building a path out of the free-form test name.
			testDir := t.TempDir()
			for _, name := range tt.files {
				if err := os.WriteFile(filepath.Join(testDir, name), tt.content(name), 0o644); err != nil {
					t.Fatalf("Setup failed: %v", err)
				}
			}

			files, err := filesForModel(testDir)
			if tt.wantErr {
				if err == nil {
					// Stop here: comparing a nil error against the expected
					// value below would only add a second, misleading failure.
					t.Fatal("Expected error, but got none")
				}
				if tt.expectErrType != nil && err != tt.expectErrType {
					t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
				}
				return
			}
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Compare on paths relative to the model directory.
			relativeFiles := make([]string, 0, len(files))
			gotSet := make(map[string]struct{}, len(files))
			for _, file := range files {
				rel, err := filepath.Rel(testDir, file)
				if err != nil {
					t.Fatalf("Failed to get relative path: %v", err)
				}
				relativeFiles = append(relativeFiles, rel)
				gotSet[rel] = struct{}{}
			}

			if len(relativeFiles) != len(tt.wantFiles) {
				t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
			}

			wantSet := make(map[string]struct{}, len(tt.wantFiles))
			for _, wantFile := range tt.wantFiles {
				wantSet[wantFile] = struct{}{}
				if _, ok := gotSet[wantFile]; !ok {
					t.Errorf("Missing expected file: %s", wantFile)
				}
			}
			// Name any file that should not have been returned, rather than
			// leaving it implied by the count mismatch above.
			for _, rel := range relativeFiles {
				if _, ok := wantSet[rel]; !ok {
					t.Errorf("Unexpected file: %s", rel)
				}
			}
		})
	}
}