diff --git a/cmd/cmd.go b/cmd/cmd.go index b0a5e7c55..cfefa35c6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -46,9 +46,8 @@ import ( var errModelfileNotFound = errors.New("specified Modelfile wasn't found") func getModelfileName(cmd *cobra.Command) (string, error) { - fn, _ := cmd.Flags().GetString("file") + filename, _ := cmd.Flags().GetString("file") - filename := fn if filename == "" { filename = "Modelfile" } @@ -60,7 +59,7 @@ func getModelfileName(cmd *cobra.Command) (string, error) { _, err = os.Stat(absName) if err != nil { - return fn, err + return filename, err } return absName, nil @@ -100,7 +99,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner := progress.NewSpinner(status) p.Add(status, spinner) - req, err := modelfile.CreateRequest() + req, err := modelfile.CreateRequest(filepath.Dir(filename)) if err != nil { return err } diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 18488048f..069428bec 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -279,7 +279,7 @@ func TestGetModelfileName(t *testing.T) { name: "no modelfile specified, no modelfile exists", modelfileName: "", fileExists: false, - expectedName: "", + expectedName: "Modelfile", expectedErr: os.ErrNotExist, }, { @@ -338,8 +338,8 @@ func TestGetModelfileName(t *testing.T) { t.Fatalf("couldn't set file flag: %v", err) } } else { + expectedFilename = tt.expectedName if tt.modelfileName != "" { - expectedFilename = tt.modelfileName err := cmd.Flags().Set("file", tt.modelfileName) if err != nil { t.Fatalf("couldn't set file flag: %v", err) @@ -489,3 +489,130 @@ func TestPushHandler(t *testing.T) { }) } } + +func TestCreateHandler(t *testing.T) { + tests := []struct { + name string + modelName string + modelFile string + serverResponse map[string]func(w http.ResponseWriter, r *http.Request) + expectedError string + expectedOutput string + }{ + { + name: "successful create", + modelName: "test-model", + modelFile: "FROM foo", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/create": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + + req := api.CreateRequest{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if req.Name != "test-model" { + t.Errorf("expected model name 'test-model', got %s", req.Name) + } + + if req.From != "foo" { + t.Errorf("expected from 'foo', got %s", req.From) + } + + responses := []api.ProgressResponse{ + {Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}, + {Status: "writing manifest"}, + {Status: "success"}, + } + + for _, resp := range responses { + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.(http.Flusher).Flush() + } + }, + }, + expectedOutput: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler, ok := tt.serverResponse[r.URL.Path] + if !ok { + t.Errorf("unexpected request to %s", r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + return + } + handler(w, r) + })) + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + tempFile, err := os.CreateTemp("", "modelfile") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + if _, err := tempFile.WriteString(tt.modelFile); err != nil { + t.Fatal(err) + } + if err := tempFile.Close(); err != nil { + t.Fatal(err) + } + + cmd := &cobra.Command{} + cmd.Flags().String("file", "", "") + if err := cmd.Flags().Set("file", tempFile.Name()); err != nil { + t.Fatal(err) + } + + cmd.Flags().Bool("insecure", false, "") + cmd.SetContext(context.TODO()) + + // Redirect stderr to capture progress output + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + // Capture stdout for the "Model pushed" message + oldStdout := os.Stdout + outR, outW, _ := os.Pipe() + os.Stdout = outW + + err = CreateHandler(cmd, []string{tt.modelName}) + + // Restore stderr + w.Close() + os.Stderr = oldStderr + // drain the pipe + if _, err := io.ReadAll(r); err != nil { + t.Fatal(err) + } + + // Restore stdout and get output + outW.Close() + os.Stdout = oldStdout + stdout, _ := io.ReadAll(outR) + + if tt.expectedError == "" { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if tt.expectedOutput != "" { + if got := string(stdout); got != tt.expectedOutput { + t.Errorf("expected output %q, got %q", tt.expectedOutput, got) + } + } + } + }) + } +} diff --git a/parser/expandpath_test.go b/parser/expandpath_test.go index c51e01cbe..d27626b05 100644 --- a/parser/expandpath_test.go +++ b/parser/expandpath_test.go @@ -31,27 +31,29 @@ func TestExpandPath(t *testing.T) { } tests := []struct { - input string + path string + relativeDir string expected string windowsExpected string shouldErr bool }{ - {"~", "/home/testuser", "D:\\home\\testuser", false}, - {"~/myfolder/myfile.txt", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false}, - {"~anotheruser/docs/file.txt", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false}, - {"~nonexistentuser/file.txt", "", "", true}, - {"relative/path/to/file", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false}, - {"/absolute/path/to/file", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false}, - {".", os.Getenv("PWD"), os.Getenv("PWD"), false}, + {"~", "", "/home/testuser", "D:\\home\\testuser", false}, + {"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false}, + {"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false}, + {"~nonexistentuser/file.txt", "", "", "", true}, + {"relative/path/to/file", "", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false}, + {"/absolute/path/to/file", "", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false}, + {".", os.Getenv("PWD"), "", os.Getenv("PWD"), false}, + {"somefile", "somedir", filepath.Join(os.Getenv("PWD"), "somedir", "somefile"), "somedir\\somefile", false}, } for _, test := range tests { - result, err := expandPathImpl(test.input, mockCurrentUser, mockLookupUser) + result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser) if (err != nil) != test.shouldErr { - t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.input, err != nil, test.shouldErr) + t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr) } if result != test.expected && result != test.windowsExpected && !test.shouldErr { - t.Errorf("expandPathImpl(%q) = %q, want %q", test.input, result, test.expected) + t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected) } } } diff --git a/parser/parser.go b/parser/parser.go index 520664de6..40acf3e5f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -39,7 +39,7 @@ func (f Modelfile) String() string { var deprecatedParameters = []string{"penalize_newline"} // CreateRequest creates a new *api.CreateRequest from an existing Modelfile -func (f Modelfile) CreateRequest() (*api.CreateRequest, error) { +func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) { req := &api.CreateRequest{} var messages []api.Message @@ -49,7 +49,7 @@ func (f Modelfile) CreateRequest() (*api.CreateRequest, error) { for _, c := range f.Commands { switch c.Name { case "model": - path, err := expandPath(c.Args) + path, err := expandPath(c.Args, relativeDir) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (f Modelfile) CreateRequest() (*api.CreateRequest, error) { req.Files = digestMap case "adapter": - path, err := expandPath(c.Args) + path, err := expandPath(c.Args, relativeDir) if err != nil { return nil, err } @@ -563,7 +563,7 @@ func isValidCommand(cmd string) bool { } } -func expandPathImpl(path string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) { +func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) { if strings.HasPrefix(path, "~") { var homeDir string @@ -591,11 +591,13 @@ func expandPathImpl(path string, currentUserFunc func() (*user.User, error), loo } path = filepath.Join(homeDir, path) + } else { + path = filepath.Join(relativeDir, path) } return filepath.Abs(path) } -func expandPath(path string) (string, error) { - return expandPathImpl(path, user.Current, user.Lookup) +func expandPath(path, relativeDir string) (string, error) { + return expandPathImpl(path, relativeDir, user.Current, user.Lookup) } diff --git a/parser/parser_test.go b/parser/parser_test.go index 169cf10fd..429bdc64b 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -747,7 +747,7 @@ MESSAGE assistant Hi! How are you? t.Error(err) } - actual, err := p.CreateRequest() + actual, err := p.CreateRequest("") if err != nil { t.Error(err) } @@ -816,7 +816,7 @@ func TestCreateRequestFiles(t *testing.T) { t.Error(err) } - actual, err := p.CreateRequest() + actual, err := p.CreateRequest("") if err != nil { t.Error(err) }