diff --git a/parser/expandpath_test.go b/parser/expandpath_test.go index d27626b05..845f919cf 100644 --- a/parser/expandpath_test.go +++ b/parser/expandpath_test.go @@ -4,6 +4,7 @@ import ( "os" "os/user" "path/filepath" + "runtime" "testing" ) @@ -11,14 +12,29 @@ func TestExpandPath(t *testing.T) { mockCurrentUser := func() (*user.User, error) { return &user.User{ Username: "testuser", - HomeDir: "/home/testuser", + HomeDir: func() string { + if os.PathSeparator == '\\' { + return filepath.FromSlash("D:/home/testuser") + } + return "/home/testuser" + }(), }, nil } mockLookupUser := func(username string) (*user.User, error) { fakeUsers := map[string]string{ - "testuser": "/home/testuser", - "anotheruser": "/home/anotheruser", + "testuser": func() string { + if os.PathSeparator == '\\' { + return filepath.FromSlash("D:/home/testuser") + } + return "/home/testuser" + }(), + "anotheruser": func() string { + if os.PathSeparator == '\\' { + return filepath.FromSlash("D:/home/anotheruser") + } + return "/home/anotheruser" + }(), } if homeDir, ok := fakeUsers[username]; ok { @@ -30,30 +46,78 @@ func TestExpandPath(t *testing.T) { return nil, os.ErrNotExist } - tests := []struct { - path string - relativeDir string - expected string - windowsExpected string - shouldErr bool - }{ - {"~", "", "/home/testuser", "D:\\home\\testuser", false}, - {"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false}, - {"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false}, - {"~nonexistentuser/file.txt", "", "", "", true}, - {"relative/path/to/file", "", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false}, - {"/absolute/path/to/file", "", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false}, - {".", os.Getenv("PWD"), "", os.Getenv("PWD"), false}, - {"somefile", "somedir", filepath.Join(os.Getenv("PWD"), "somedir", "somefile"), "somedir\\somefile", false}, + pwd, err := os.Getwd() + if err != nil { + t.Fatal(err) } - for _, test := range tests { - result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser) - if (err != nil) != test.shouldErr { - t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr) + t.Run("unix tests", func(t *testing.T) { + if runtime.GOOS == "windows" { + return } - if result != test.expected && result != test.windowsExpected && !test.shouldErr { - t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected) + + tests := []struct { + path string + relativeDir string + expected string + shouldErr bool + }{ + {"~", "", "/home/testuser", false}, + {"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", false}, + {"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", false}, + {"~nonexistentuser/file.txt", "", "", true}, + {"relative/path/to/file", "", filepath.Join(pwd, "relative/path/to/file"), false}, + {"/absolute/path/to/file", "", "/absolute/path/to/file", false}, + {"/absolute/path/to/file", "someotherdir/", "/absolute/path/to/file", false}, + {".", pwd, pwd, false}, + {".", "", pwd, false}, + {"somefile", "somedir", filepath.Join(pwd, "somedir", "somefile"), false}, } - } + + for _, test := range tests { + result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser) + if (err != nil) != test.shouldErr { + t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr) + } + + if result != test.expected && !test.shouldErr { + t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected) + } + } + }) + + t.Run("windows tests", func(t *testing.T) { + if runtime.GOOS != "windows" { + return + } + + tests := []struct { + path string + relativeDir string + expected string + shouldErr bool + }{ + {"~", "", "D:\\home\\testuser", false}, + {"~/myfolder/myfile.txt", "", "D:\\home\\testuser\\myfolder\\myfile.txt", false}, + {"~anotheruser/docs/file.txt", "", "D:\\home\\anotheruser\\docs\\file.txt", false}, + {"~nonexistentuser/file.txt", "", "", true}, + {"relative\\path\\to\\file", "", filepath.Join(pwd, "relative\\path\\to\\file"), false}, + {"D:\\absolute\\path\\to\\file", "", "D:\\absolute\\path\\to\\file", false}, + {"D:\\absolute\\path\\to\\file", "someotherdir/", "D:\\absolute\\path\\to\\file", false}, + {".", pwd, pwd, false}, + {".", "", pwd, false}, + {"somefile", "somedir", filepath.Join(pwd, "somedir", "somefile"), false}, + } + + for _, test := range tests { + result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser) + if (err != nil) != test.shouldErr { + t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr) + } + + if result != test.expected && !test.shouldErr { + t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected) + } + } + }) } diff --git a/parser/parser.go b/parser/parser.go index 40acf3e5f..d5df479a5 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -564,7 +564,9 @@ func isValidCommand(cmd string) bool { } func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) { - if strings.HasPrefix(path, "~") { + if filepath.IsAbs(path) || strings.HasPrefix(path, "\\") || strings.HasPrefix(path, "/") { + return filepath.Abs(path) + } else if strings.HasPrefix(path, "~") { var homeDir string if path == "~" || strings.HasPrefix(path, "~/") { diff --git a/server/create.go b/server/create.go index 5856b595c..6120c705b 100644 --- a/server/create.go +++ b/server/create.go @@ -178,12 +178,37 @@ func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isA } func detectModelTypeFromFiles(files map[string]string) string { - // todo make this more robust by actually introspecting the files for fn := range files { if strings.HasSuffix(fn, ".safetensors") { return "safetensors" - } else if strings.HasSuffix(fn, ".bin") || strings.HasSuffix(fn, ".gguf") { + } else if strings.HasSuffix(fn, ".gguf") { return "gguf" + } else { + // try to see if we can find a gguf file even without the file extension + blobPath, err := GetBlobsPath(files[fn]) + if err != nil { + slog.Error("error getting blobs path", "file", fn) + return "" + } + + f, err := os.Open(blobPath) + if err != nil { + slog.Error("error reading file", "error", err) + return "" + } + defer f.Close() + + buf := make([]byte, 4) + _, err = f.Read(buf) + if err != nil { + slog.Error("error reading file", "error", err) + return "" + } + + ct := llm.DetectGGMLType(buf) + if ct == "gguf" { + return "gguf" + } } } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 9c85eb9d5..92b9e4aa3 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -3,6 +3,7 @@ package server import ( "bytes" "cmp" + "crypto/sha256" "encoding/json" "fmt" "io" @@ -710,3 +711,100 @@ func TestCreateDetectTemplate(t *testing.T) { }) }) } + +func TestDetectModelTypeFromFiles(t *testing.T) { + t.Run("gguf file", func(t *testing.T) { + _, digest := createBinFile(t, nil, nil) + files := map[string]string{ + "model.gguf": digest, + } + + modelType := detectModelTypeFromFiles(files) + if modelType != "gguf" { + t.Fatalf("expected model type 'gguf', got %q", modelType) + } + }) + + t.Run("gguf file w/o extension", func(t *testing.T) { + _, digest := createBinFile(t, nil, nil) + files := map[string]string{ + fmt.Sprintf("%x", digest): digest, + } + + modelType := detectModelTypeFromFiles(files) + if modelType != "gguf" { + t.Fatalf("expected model type 'gguf', got %q", modelType) + } + }) + + t.Run("safetensors file", func(t *testing.T) { + files := map[string]string{ + "model.safetensors": "sha256:abc123", + } + + modelType := detectModelTypeFromFiles(files) + if modelType != "safetensors" { + t.Fatalf("expected model type 'safetensors', got %q", modelType) + } + }) + + t.Run("unsupported file type", func(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + data := []byte("12345678") + digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data)) + if err := os.MkdirAll(filepath.Join(p, "blobs"), 0o755); err != nil { + t.Fatal(err) + } + + f, err := os.Create(filepath.Join(p, "blobs", fmt.Sprintf("sha256-%s", strings.TrimPrefix(digest, "sha256:")))) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if _, err := f.Write(data); err != nil { + t.Fatal(err) + } + + files := map[string]string{ + "model.bin": digest, + } + + modelType := detectModelTypeFromFiles(files) + if modelType != "" { + t.Fatalf("expected empty model type for unsupported file, got %q", modelType) + } + }) + + t.Run("file with less than 4 bytes", func(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + data := []byte("123") + digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data)) + if err := os.MkdirAll(filepath.Join(p, "blobs"), 0o755); err != nil { + t.Fatal(err) + } + + f, err := os.Create(filepath.Join(p, "blobs", fmt.Sprintf("sha256-%s", strings.TrimPrefix(digest, "sha256:")))) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if _, err := f.Write(data); err != nil { + t.Fatal(err) + } + + files := map[string]string{ + "noext": digest, + } + + modelType := detectModelTypeFromFiles(files) + if modelType != "" { + t.Fatalf("expected empty model type for small file, got %q", modelType) + } + }) +}