ollama/server/create_test.go
Bruce MacDonald bebb6823c0
server: validate local path on safetensor create (#9379)
More validation during the safetensor creation process.
Properly handle relative paths (like ./model.safetensors) while rejecting absolute paths
Add comprehensive test coverage for various paths
No functionality changes for valid inputs - existing workflows remain unaffected
Leverages Go 1.24's new os.Root functionality for secure containment
2025-02-28 16:10:43 -08:00

107 lines
2.5 KiB
Go

package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestConvertFromSafetensors(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
return l.Digest
}
// Create a safetensors compatible file with empty JSON content
var buf bytes.Buffer
headerSize := int64(len("{}"))
binary.Write(&buf, binary.LittleEndian, headerSize)
buf.WriteString("{}")
model := makeTemp(buf.String())
config := makeTemp(`{
"architectures": ["LlamaForCausalLM"],
"vocab_size": 32000
}`)
tokenizer := makeTemp(`{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}`)
tests := []struct {
name string
filePath string
wantErr error
}{
// Invalid
{
name: "InvalidRelativePathShallow",
filePath: filepath.Join("..", "file.safetensors"),
wantErr: errFilePath,
},
{
name: "InvalidRelativePathDeep",
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
wantErr: errFilePath,
},
{
name: "InvalidNestedPath",
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
wantErr: errFilePath,
},
{
name: "AbsolutePathOutsideRoot",
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
wantErr: errFilePath, // Should fail since it's outside tmpDir
},
{
name: "ValidRelativePath",
filePath: "model.safetensors",
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create the minimum required file map for convertFromSafetensors
files := map[string]string{
tt.filePath: model,
"config.json": config,
"tokenizer.json": tokenizer,
}
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
if (tt.wantErr == nil && err != nil) ||
(tt.wantErr != nil && err == nil) ||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}