mirror of
https://github.com/ollama/ollama.git
synced 2025-03-24 00:31:55 +01:00
More validation during the safetensor creation process. Properly handle relative paths (like ./model.safetensors) while rejecting absolute paths. Add comprehensive test coverage for various paths. No functionality changes for valid inputs - existing workflows remain unaffected. Leverages Go 1.24's new os.Root functionality for secure containment.
107 lines
2.5 KiB
Go
107 lines
2.5 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
func TestConvertFromSafetensors(t *testing.T) {
|
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
|
|
|
// Helper function to create a new layer and return its digest
|
|
makeTemp := func(content string) string {
|
|
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create layer: %v", err)
|
|
}
|
|
return l.Digest
|
|
}
|
|
|
|
// Create a safetensors compatible file with empty JSON content
|
|
var buf bytes.Buffer
|
|
headerSize := int64(len("{}"))
|
|
binary.Write(&buf, binary.LittleEndian, headerSize)
|
|
buf.WriteString("{}")
|
|
|
|
model := makeTemp(buf.String())
|
|
config := makeTemp(`{
|
|
"architectures": ["LlamaForCausalLM"],
|
|
"vocab_size": 32000
|
|
}`)
|
|
tokenizer := makeTemp(`{
|
|
"version": "1.0",
|
|
"truncation": null,
|
|
"padding": null,
|
|
"added_tokens": [
|
|
{
|
|
"id": 0,
|
|
"content": "<|endoftext|>",
|
|
"single_word": false,
|
|
"lstrip": false,
|
|
"rstrip": false,
|
|
"normalized": false,
|
|
"special": true
|
|
}
|
|
]
|
|
}`)
|
|
|
|
tests := []struct {
|
|
name string
|
|
filePath string
|
|
wantErr error
|
|
}{
|
|
// Invalid
|
|
{
|
|
name: "InvalidRelativePathShallow",
|
|
filePath: filepath.Join("..", "file.safetensors"),
|
|
wantErr: errFilePath,
|
|
},
|
|
{
|
|
name: "InvalidRelativePathDeep",
|
|
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
|
|
wantErr: errFilePath,
|
|
},
|
|
{
|
|
name: "InvalidNestedPath",
|
|
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
|
|
wantErr: errFilePath,
|
|
},
|
|
{
|
|
name: "AbsolutePathOutsideRoot",
|
|
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
|
|
wantErr: errFilePath, // Should fail since it's outside tmpDir
|
|
},
|
|
{
|
|
name: "ValidRelativePath",
|
|
filePath: "model.safetensors",
|
|
wantErr: nil,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create the minimum required file map for convertFromSafetensors
|
|
files := map[string]string{
|
|
tt.filePath: model,
|
|
"config.json": config,
|
|
"tokenizer.json": tokenizer,
|
|
}
|
|
|
|
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
|
|
|
|
if (tt.wantErr == nil && err != nil) ||
|
|
(tt.wantErr != nil && err == nil) ||
|
|
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
|
|
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|