From 630518f0d95babd19ab3b717f86a2405c3326b2d Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 14 Dec 2023 16:47:40 -0800 Subject: [PATCH] Add unit test of API routes (#1528) --- cmd/cmd.go | 7 +---- go.mod | 3 ++ server/routes.go | 72 ++++++++++++++++++++++++++++--------------- server/routes_test.go | 70 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 30 deletions(-) create mode 100644 server/routes_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index bd3db19a8..04add5ff8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1035,12 +1035,7 @@ func RunServer(cmd *cobra.Command, _ []string) error { return err } - var origins []string - if o := os.Getenv("OLLAMA_ORIGINS"); o != "" { - origins = strings.Split(o, ",") - } - - return server.Serve(ln, origins) + return server.Serve(ln) } func getImageData(filePath string) ([]byte, error) { diff --git a/go.mod b/go.mod index fd0752d0a..1bba54f99 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,14 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.8.3 golang.org/x/sync v0.3.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect ) diff --git a/server/routes.go b/server/routes.go index 71c27b89c..002f5f21f 100644 --- a/server/routes.go +++ b/server/routes.go @@ -32,6 +32,10 @@ import ( var mode string = gin.DebugMode +type Server struct { + WorkDir string +} + func init() { switch mode { case gin.DebugMode: @@ -800,27 +804,27 @@ var defaultAllowOrigins = []string{ "0.0.0.0", } -func Serve(ln net.Listener, allowOrigins []string) error { - if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { - // clean up unused layers and manifests - if err := PruneLayers(); err != nil { - return err - } +func NewServer() (*Server, error) { + workDir, err := os.MkdirTemp("", "ollama") + if err != nil { + return nil, err + } - manifestsPath, err := GetManifestPath() - if err != nil { - return err - } + return &Server{ + WorkDir: workDir, + }, nil +} - if err := PruneDirectory(manifestsPath); err != nil { - return err - } +func (s *Server) GenerateRoutes() http.Handler { + var origins []string + if o := os.Getenv("OLLAMA_ORIGINS"); o != "" { + origins = strings.Split(o, ",") } config := cors.DefaultConfig() config.AllowWildcard = true - config.AllowOrigins = allowOrigins + config.AllowOrigins = origins for _, allowOrigin := range defaultAllowOrigins { config.AllowOrigins = append(config.AllowOrigins, fmt.Sprintf("http://%s", allowOrigin), @@ -830,17 +834,11 @@ func Serve(ln net.Listener, allowOrigins []string) error { ) } - workDir, err := os.MkdirTemp("", "ollama") - if err != nil { - return err - } - defer os.RemoveAll(workDir) - r := gin.Default() r.Use( cors.New(config), func(c *gin.Context) { - c.Set("workDir", workDir) + c.Set("workDir", s.WorkDir) c.Next() }, ) @@ -868,8 +866,34 @@ func Serve(ln net.Listener, allowOrigins []string) error { }) } + return r +} + +func Serve(ln net.Listener) error { + if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + // clean up unused layers and manifests + if err := PruneLayers(); err != nil { + return err + } + + manifestsPath, err := GetManifestPath() + if err != nil { + return err + } + + if err := PruneDirectory(manifestsPath); err != nil { + return err + } + } + + s, err := NewServer() + if err != nil { + return err + } + r := s.GenerateRoutes() + log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version) - s := &http.Server{ + srvr := &http.Server{ Handler: r, } @@ -881,7 +905,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { if loaded.runner != nil { loaded.runner.Close() } - os.RemoveAll(workDir) + os.RemoveAll(s.WorkDir) os.Exit(0) }() @@ -892,7 +916,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { } } - return s.Serve(ln) + return srvr.Serve(ln) } func waitForStream(c *gin.Context, ch chan interface{}) { diff --git a/server/routes_test.go b/server/routes_test.go new file mode 100644 index 000000000..70071b2dc --- /dev/null +++ b/server/routes_test.go @@ -0,0 +1,70 @@ +package server + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func setupServer(t *testing.T) (*Server, error) { + t.Helper() + + return NewServer() +} + +func Test_Routes(t *testing.T) { + type testCase struct { + Name string + Method string + Path string + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, resp *http.Response) + } + + testCases := []testCase{ + { + Name: "Version Handler", + Method: http.MethodGet, + Path: "/api/version", + Setup: func(t *testing.T, req *http.Request) { + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, contentType, "application/json; charset=utf-8") + body, err := io.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, `{"version":"0.0.0"}`, string(body)) + }, + }, + } + + s, err := setupServer(t) + assert.Nil(t, err) + + router := s.GenerateRoutes() + + httpSrv := httptest.NewServer(router) + t.Cleanup(httpSrv.Close) + + for _, tc := range testCases { + u := httpSrv.URL + tc.Path + req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) + assert.Nil(t, err) + + if tc.Setup != nil { + tc.Setup(t, req) + } + + resp, err := httpSrv.Client().Do(req) + assert.Nil(t, err) + + if tc.Expected != nil { + tc.Expected(t, resp) + } + } + +}