diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index c89632800..e70ffbeab 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -10,6 +10,7 @@ import ( "os" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/spf13/cobra" @@ -490,6 +491,96 @@ func TestPushHandler(t *testing.T) { } } +func TestListHandler(t *testing.T) { + tests := []struct { + name string + args []string + serverResponse []api.ListModelResponse + expectedError string + expectedOutput string + }{ + { + name: "list all models", + args: []string{}, + serverResponse: []api.ListModelResponse{ + {Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)}, + {Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)}, + }, + expectedOutput: "NAME ID SIZE MODIFIED \n" + + "model1 sha256:abc12 1.0 KB 24 hours ago \n" + + "model2 sha256:def45 2.0 KB 2 days ago \n", + }, + { + name: "filter models by prefix", + args: []string{"model1"}, + serverResponse: []api.ListModelResponse{ + {Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)}, + {Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)}, + }, + expectedOutput: "NAME ID SIZE MODIFIED \n" + + "model1 sha256:abc12 1.0 KB 24 hours ago \n", + }, + { + name: "server error", + args: []string{}, + expectedError: "server error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/tags" || r.Method != http.MethodGet { + t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + return + } + + if tt.expectedError != "" { + http.Error(w, tt.expectedError, http.StatusInternalServerError) + return + } + + response := api.ListResponse{Models: tt.serverResponse} + if err := json.NewEncoder(w).Encode(response); err != nil { + t.Fatal(err) + } + })) + defer mockServer.Close() + + t.Setenv("OLLAMA_HOST", mockServer.URL) + + cmd := &cobra.Command{} + cmd.SetContext(context.TODO()) + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + err := ListHandler(cmd, tt.args) + + // Restore stdout and get output + w.Close() + os.Stdout = oldStdout + output, _ := io.ReadAll(r) + + if tt.expectedError == "" { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if got := string(output); got != tt.expectedOutput { + t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got) + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %v", tt.expectedError, err) + } + } + }) + } +} + func TestCreateHandler(t *testing.T) { tests := []struct { name string