make the modelfile path relative for ollama create (#8380)

This commit is contained in:
Patrick Devine 2025-01-10 16:14:08 -08:00 committed by GitHub
parent 9446c2c902
commit 32bd37adf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 155 additions and 25 deletions

View File

@ -46,9 +46,8 @@ import (
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
fn, _ := cmd.Flags().GetString("file")
filename, _ := cmd.Flags().GetString("file")
filename := fn
if filename == "" {
filename = "Modelfile"
}
@ -60,7 +59,7 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
_, err = os.Stat(absName)
if err != nil {
return fn, err
return filename, err
}
return absName, nil
@ -100,7 +99,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
req, err := modelfile.CreateRequest()
req, err := modelfile.CreateRequest(filepath.Dir(filename))
if err != nil {
return err
}

View File

@ -279,7 +279,7 @@ func TestGetModelfileName(t *testing.T) {
name: "no modelfile specified, no modelfile exists",
modelfileName: "",
fileExists: false,
expectedName: "",
expectedName: "Modelfile",
expectedErr: os.ErrNotExist,
},
{
@ -338,8 +338,8 @@ func TestGetModelfileName(t *testing.T) {
t.Fatalf("couldn't set file flag: %v", err)
}
} else {
expectedFilename = tt.expectedName
if tt.modelfileName != "" {
expectedFilename = tt.modelfileName
err := cmd.Flags().Set("file", tt.modelfileName)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
@ -489,3 +489,130 @@ func TestPushHandler(t *testing.T) {
})
}
}
func TestCreateHandler(t *testing.T) {
tests := []struct {
name string
modelName string
modelFile string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
name: "successful create",
modelName: "test-model",
modelFile: "FROM foo",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/create": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
req := api.CreateRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}
if req.From != "foo" {
t.Errorf("expected from 'foo', got %s", req.From)
}
responses := []api.ProgressResponse{
{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
{Status: "writing manifest"},
{Status: "success"},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler, ok := tt.serverResponse[r.URL.Path]
if !ok {
t.Errorf("unexpected request to %s", r.URL.Path)
http.Error(w, "not found", http.StatusNotFound)
return
}
handler(w, r)
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp("", "modelfile")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.WriteString(tt.modelFile); err != nil {
t.Fatal(err)
}
if err := tempFile.Close(); err != nil {
t.Fatal(err)
}
cmd := &cobra.Command{}
cmd.Flags().String("file", "", "")
if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
t.Fatal(err)
}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err = CreateHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
}
})
}
}

View File

@ -31,27 +31,29 @@ func TestExpandPath(t *testing.T) {
}
tests := []struct {
input string
path string
relativeDir string
expected string
windowsExpected string
shouldErr bool
}{
{"~", "/home/testuser", "D:\\home\\testuser", false},
{"~/myfolder/myfile.txt", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false},
{"~anotheruser/docs/file.txt", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false},
{"~nonexistentuser/file.txt", "", "", true},
{"relative/path/to/file", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false},
{"/absolute/path/to/file", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false},
{".", os.Getenv("PWD"), os.Getenv("PWD"), false},
{"~", "", "/home/testuser", "D:\\home\\testuser", false},
{"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false},
{"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false},
{"~nonexistentuser/file.txt", "", "", "", true},
{"relative/path/to/file", "", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false},
{"/absolute/path/to/file", "", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false},
{".", os.Getenv("PWD"), "", os.Getenv("PWD"), false},
{"somefile", "somedir", filepath.Join(os.Getenv("PWD"), "somedir", "somefile"), "somedir\\somefile", false},
}
for _, test := range tests {
result, err := expandPathImpl(test.input, mockCurrentUser, mockLookupUser)
result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser)
if (err != nil) != test.shouldErr {
t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.input, err != nil, test.shouldErr)
t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr)
}
if result != test.expected && result != test.windowsExpected && !test.shouldErr {
t.Errorf("expandPathImpl(%q) = %q, want %q", test.input, result, test.expected)
t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected)
}
}
}

View File

@ -39,7 +39,7 @@ func (f Modelfile) String() string {
var deprecatedParameters = []string{"penalize_newline"}
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
req := &api.CreateRequest{}
var messages []api.Message
@ -49,7 +49,7 @@ func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
for _, c := range f.Commands {
switch c.Name {
case "model":
path, err := expandPath(c.Args)
path, err := expandPath(c.Args, relativeDir)
if err != nil {
return nil, err
}
@ -64,7 +64,7 @@ func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
req.Files = digestMap
case "adapter":
path, err := expandPath(c.Args)
path, err := expandPath(c.Args, relativeDir)
if err != nil {
return nil, err
}
@ -563,7 +563,7 @@ func isValidCommand(cmd string) bool {
}
}
func expandPathImpl(path string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) {
func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) {
if strings.HasPrefix(path, "~") {
var homeDir string
@ -591,11 +591,13 @@ func expandPathImpl(path string, currentUserFunc func() (*user.User, error), loo
}
path = filepath.Join(homeDir, path)
} else {
path = filepath.Join(relativeDir, path)
}
return filepath.Abs(path)
}
func expandPath(path string) (string, error) {
return expandPathImpl(path, user.Current, user.Lookup)
func expandPath(path, relativeDir string) (string, error) {
return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
}

View File

@ -747,7 +747,7 @@ MESSAGE assistant Hi! How are you?
t.Error(err)
}
actual, err := p.CreateRequest()
actual, err := p.CreateRequest("")
if err != nil {
t.Error(err)
}
@ -816,7 +816,7 @@ func TestCreateRequestFiles(t *testing.T) {
t.Error(err)
}
actual, err := p.CreateRequest()
actual, err := p.CreateRequest("")
if err != nil {
t.Error(err)
}