From 246ea2886461320e1dda8218652f8c53ced0d24b Mon Sep 17 00:00:00 2001 From: Viktor Sokolov Date: Thu, 11 Sep 2025 10:15:31 +0200 Subject: [PATCH] TestServer, AllowNetworks -> http.Transport --- auximageprovider/static_provider_test.go | 85 +++---- config/config.go | 6 +- fetcher/errors.go | 13 +- fetcher/transport/generichttp/config.go | 17 +- fetcher/transport/generichttp/errors.go | 23 ++ fetcher/transport/generichttp/generic_http.go | 30 ++- .../generichttp/generic_http_test.go | 17 +- handlers/processing/handler_test.go | 1 - handlers/stream/handler_test.go | 222 +++++++----------- imagedata/image_data_test.go | 194 +++++++-------- integration/processing_handler_test.go | 26 +- options/processing_options.go | 4 +- options/processing_options_test.go | 2 +- security/errors.go | 13 - security/source.go | 29 --- testutil/lasy_suite.go | 2 +- testutil/lazy_obj.go | 4 +- testutil/test_data_provider.go | 3 - testutil/test_server.go | 122 ++++++++++ 19 files changed, 425 insertions(+), 388 deletions(-) create mode 100644 fetcher/transport/generichttp/errors.go rename security/source_test.go => fetcher/transport/generichttp/generic_http_test.go (82%) delete mode 100644 handlers/processing/handler_test.go create mode 100644 testutil/test_server.go diff --git a/auximageprovider/static_provider_test.go b/auximageprovider/static_provider_test.go index 08cb273e..d71a8554 100644 --- a/auximageprovider/static_provider_test.go +++ b/auximageprovider/static_provider_test.go @@ -3,77 +3,56 @@ package auximageprovider import ( "encoding/base64" "io" - "net/http" - "net/http/httptest" - "os" "strconv" "testing" "github.com/stretchr/testify/suite" - "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/fetcher" "github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/options" + "github.com/imgproxy/imgproxy/v3/testutil" ) type ImageProviderTestSuite struct { - suite.Suite + testutil.LazySuite - server *httptest.Server testData []byte testDataB64 string - // Server state - status int - data []byte - header http.Header + testServer testutil.LazyTestServer + idf *imagedata.Factory } func (s *ImageProviderTestSuite) SetupSuite() { - config.Reset() - config.AllowLoopbackSourceAddresses = true + s.testData = testutil.NewTestDataProvider(s.T).Read("test1.jpg") + s.testDataB64 = base64.StdEncoding.EncodeToString(s.testData) - // Load test image data - f, err := os.Open("../testdata/test1.jpg") - s.Require().NoError(err) - defer f.Close() + fc := fetcher.NewDefaultConfig() + fc.Transport.HTTP.AllowLoopbackSourceAddresses = true - data, err := io.ReadAll(f) + f, err := fetcher.New(&fc) s.Require().NoError(err) - s.testData = data - s.testDataB64 = base64.StdEncoding.EncodeToString(data) + s.idf = imagedata.NewFactory(f) - // Create test server - s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - for k, vv := range s.header { - for _, v := range vv { - rw.Header().Add(k, v) - } - } + s.testServer, _ = testutil.NewLazySuiteTestServer( + s, + func(srv *testutil.TestServer) error { + srv.SetHeaders( + httpheaders.ContentType, "image/jpeg", + httpheaders.ContentLength, strconv.Itoa(len(s.testData)), + ).SetBody(s.testData) - data := s.data - if data == nil { - data = s.testData - } - - rw.Header().Set(httpheaders.ContentLength, strconv.Itoa(len(data))) - rw.WriteHeader(s.status) - rw.Write(data) - })) + return nil + }, + ) } -func (s *ImageProviderTestSuite) TearDownSuite() { - s.server.Close() -} - -func (s *ImageProviderTestSuite) SetupTest() { - s.status = http.StatusOK - s.data = nil - s.header = http.Header{} - s.header.Set(httpheaders.ContentType, "image/jpeg") +func (s *ImageProviderTestSuite) SetupSubTest() { + // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest + s.ResetLazyObjects() } // Helper function to read data from ImageData @@ -114,7 +93,7 @@ func (s *ImageProviderTestSuite) TestNewProvider() { }, { name: "URL", - config: &StaticConfig{URL: s.server.URL}, + config: &StaticConfig{URL: s.testServer().URL()}, validateFunc: func(provider Provider) { s.Equal(s.testData, s.readImageData(provider)) }, @@ -149,10 +128,12 @@ func (s *ImageProviderTestSuite) TestNewProvider() { }, { name: "HeadersPassedThrough", - config: &StaticConfig{URL: s.server.URL}, + config: &StaticConfig{URL: s.testServer().URL()}, setupFunc: func() { - s.header.Set("X-Custom-Header", "test-value") - s.header.Set(httpheaders.CacheControl, "max-age=3600") + s.testServer().SetHeaders( + "X-Custom-Header", "test-value", + httpheaders.CacheControl, "max-age=3600", + ) }, validateFunc: func(provider Provider) { imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{}) @@ -167,19 +148,13 @@ func (s *ImageProviderTestSuite) TestNewProvider() { }, } - fc := fetcher.NewDefaultConfig() - f, err := fetcher.New(&fc) - s.Require().NoError(err) - - idf := imagedata.NewFactory(f) - for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { if tt.setupFunc != nil { tt.setupFunc() } - provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", idf) + provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", s.idf) if tt.expectError { s.Require().Error(err) diff --git a/config/config.go b/config/config.go index bf0071d9..0907e55c 100644 --- a/config/config.go +++ b/config/config.go @@ -58,7 +58,7 @@ var ( PngUnlimited bool SvgUnlimited bool MaxResultDimension int - AllowedProcessiongOptions []string + AllowedProcessingOptions []string AllowSecurityOptions bool JpegProgressive bool @@ -267,7 +267,7 @@ func Reset() { PngUnlimited = false SvgUnlimited = false MaxResultDimension = 0 - AllowedProcessiongOptions = make([]string, 0) + AllowedProcessingOptions = make([]string, 0) AllowSecurityOptions = false JpegProgressive = false @@ -502,7 +502,7 @@ func Configure() error { configurators.Bool(&SvgUnlimited, "IMGPROXY_SVG_UNLIMITED") configurators.Int(&MaxResultDimension, "IMGPROXY_MAX_RESULT_DIMENSION") - configurators.StringSlice(&AllowedProcessiongOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS") + configurators.StringSlice(&AllowedProcessingOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS") configurators.Bool(&AllowSecurityOptions, "IMGPROXY_ALLOW_SECURITY_OPTIONS") diff --git a/fetcher/errors.go b/fetcher/errors.go index b11f2fb1..953c1385 100644 --- a/fetcher/errors.go +++ b/fetcher/errors.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" + "io" "net/http" + "github.com/imgproxy/imgproxy/v3/fetcher/transport/generichttp" "github.com/imgproxy/imgproxy/v3/ierrors" - "github.com/imgproxy/imgproxy/v3/security" ) const msgSourceImageIsUnreachable = "Source image is unreachable" @@ -157,13 +158,21 @@ func (e NotModifiedError) Headers() http.Header { func WrapError(err error) error { isTimeout := false - var secArrdErr security.SourceAddressError + var secArrdErr generichttp.SourceAddressError switch { case errors.Is(err, context.DeadlineExceeded): isTimeout = true case errors.Is(err, context.Canceled): return newImageRequestCanceledError(err) + case err == io.ErrUnexpectedEOF: + return ierrors.Wrap( + newImageRequestError(err), + 1, + ierrors.WithPublicMessage("source image is corrupted"), + ierrors.WithShouldReport(false), + ierrors.WithStatusCode(http.StatusUnprocessableEntity), + ) case errors.As(err, &secArrdErr): return ierrors.Wrap( err, diff --git a/fetcher/transport/generichttp/config.go b/fetcher/transport/generichttp/config.go index aa44b20f..ab29a40d 100644 --- a/fetcher/transport/generichttp/config.go +++ b/fetcher/transport/generichttp/config.go @@ -10,15 +10,21 @@ import ( // Config holds the configuration for the generic HTTP transport type Config struct { - ClientKeepAliveTimeout time.Duration - IgnoreSslVerification bool + ClientKeepAliveTimeout time.Duration + IgnoreSslVerification bool + AllowLoopbackSourceAddresses bool + AllowLinkLocalSourceAddresses bool + AllowPrivateSourceAddresses bool } // NewDefaultConfig returns a new default configuration for the generic HTTP transport func NewDefaultConfig() Config { return Config{ - ClientKeepAliveTimeout: 90 * time.Second, - IgnoreSslVerification: false, + ClientKeepAliveTimeout: 90 * time.Second, + IgnoreSslVerification: false, + AllowLoopbackSourceAddresses: false, + AllowLinkLocalSourceAddresses: false, + AllowPrivateSourceAddresses: true, } } @@ -28,6 +34,9 @@ func LoadConfigFromEnv(c *Config) (*Config, error) { c.ClientKeepAliveTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second c.IgnoreSslVerification = config.IgnoreSslVerification + c.AllowLinkLocalSourceAddresses = config.AllowLinkLocalSourceAddresses + c.AllowLoopbackSourceAddresses = config.AllowLoopbackSourceAddresses + c.AllowPrivateSourceAddresses = config.AllowPrivateSourceAddresses return c, nil } diff --git a/fetcher/transport/generichttp/errors.go b/fetcher/transport/generichttp/errors.go new file mode 100644 index 00000000..8ca90c4f --- /dev/null +++ b/fetcher/transport/generichttp/errors.go @@ -0,0 +1,23 @@ +package generichttp + +import ( + "net/http" + + "github.com/imgproxy/imgproxy/v3/ierrors" +) + +type ( + SourceAddressError string +) + +func newSourceAddressError(msg string) error { + return ierrors.Wrap( + SourceAddressError(msg), + 1, + ierrors.WithStatusCode(http.StatusNotFound), + ierrors.WithPublicMessage("Invalid source URL"), + ierrors.WithShouldReport(false), + ) +} + +func (e SourceAddressError) Error() string { return string(e) } diff --git a/fetcher/transport/generichttp/generic_http.go b/fetcher/transport/generichttp/generic_http.go index 5c37142f..bafc6e54 100644 --- a/fetcher/transport/generichttp/generic_http.go +++ b/fetcher/transport/generichttp/generic_http.go @@ -3,12 +3,12 @@ package generichttp import ( "crypto/tls" + "fmt" "net" "net/http" "syscall" "time" - "github.com/imgproxy/imgproxy/v3/security" "golang.org/x/net/http2" ) @@ -25,7 +25,7 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) { if verifyNetworks { dialer.Control = func(network, address string, c syscall.RawConn) error { - return security.VerifySourceNetwork(address) + return verifySourceNetwork(address, config) } } @@ -66,3 +66,29 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) { return transport, nil } + +func verifySourceNetwork(addr string, config *Config) error { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + + ip := net.ParseIP(host) + if ip == nil { + return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr)) + } + + if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) { + return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr)) + } + + if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) { + return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr)) + } + + if !config.AllowPrivateSourceAddresses && ip.IsPrivate() { + return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr)) + } + + return nil +} diff --git a/security/source_test.go b/fetcher/transport/generichttp/generic_http_test.go similarity index 82% rename from security/source_test.go rename to fetcher/transport/generichttp/generic_http_test.go index 4095874f..52b4c471 100644 --- a/security/source_test.go +++ b/fetcher/transport/generichttp/generic_http_test.go @@ -1,9 +1,8 @@ -package security +package generichttp import ( "testing" - "github.com/imgproxy/imgproxy/v3/config" "github.com/stretchr/testify/require" ) @@ -100,24 +99,14 @@ func TestVerifySourceNetwork(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Backup original config - originalLoopback := config.AllowLoopbackSourceAddresses - originalLinkLocal := config.AllowLinkLocalSourceAddresses - originalPrivate := config.AllowPrivateSourceAddresses - - // Restore original config after test - defer func() { - config.AllowLoopbackSourceAddresses = originalLoopback - config.AllowLinkLocalSourceAddresses = originalLinkLocal - config.AllowPrivateSourceAddresses = originalPrivate - }() + config := NewDefaultConfig() // Override config for the test config.AllowLoopbackSourceAddresses = tc.allowLoopback config.AllowLinkLocalSourceAddresses = tc.allowLinkLocal config.AllowPrivateSourceAddresses = tc.allowPrivate - err := VerifySourceNetwork(tc.addr) + err := verifySourceNetwork(tc.addr, &config) if tc.expectErr { require.Error(t, err) diff --git a/handlers/processing/handler_test.go b/handlers/processing/handler_test.go deleted file mode 100644 index ce50a100..00000000 --- a/handlers/processing/handler_test.go +++ /dev/null @@ -1 +0,0 @@ -package processing diff --git a/handlers/stream/handler_test.go b/handlers/stream/handler_test.go index 66c0d6c3..16a2b230 100644 --- a/handlers/stream/handler_test.go +++ b/handlers/stream/handler_test.go @@ -6,7 +6,6 @@ import ( "net/http" "net/http/httptest" "os" - "path/filepath" "strconv" "testing" "time" @@ -22,23 +21,24 @@ import ( "github.com/imgproxy/imgproxy/v3/testutil" ) -const ( - testDataPath = "../../testdata" -) - type HandlerTestSuite struct { testutil.LazySuite + testData *testutil.TestDataProvider + rwConf testutil.LazyObj[*responsewriter.Config] rwFactory testutil.LazyObj[*responsewriter.Factory] config testutil.LazyObj[*Config] handler testutil.LazyObj[*Handler] + + testServer testutil.LazyTestServer } func (s *HandlerTestSuite) SetupSuite() { config.Reset() - config.AllowLoopbackSourceAddresses = true + + s.testData = testutil.NewTestDataProvider(s.T) s.rwConf, _ = testutil.NewLazySuiteObj( s, @@ -67,6 +67,7 @@ func (s *HandlerTestSuite) SetupSuite() { s, func() (*Handler, error) { fc := fetcher.NewDefaultConfig() + fc.Transport.HTTP.AllowLoopbackSourceAddresses = true fetcher, err := fetcher.New(&fc) s.Require().NoError(err) @@ -75,36 +76,27 @@ func (s *HandlerTestSuite) SetupSuite() { }, ) + s.testServer, _ = testutil.NewLazySuiteTestServer(s) + // Silence logs during tests logrus.SetOutput(io.Discard) } func (s *HandlerTestSuite) TearDownSuite() { - config.Reset() logrus.SetOutput(os.Stdout) } -func (s *HandlerTestSuite) SetupTest() { - config.Reset() - config.AllowLoopbackSourceAddresses = true -} - func (s *HandlerTestSuite) SetupSubTest() { // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest s.ResetLazyObjects() } -func (s *HandlerTestSuite) readTestFile(name string) []byte { - data, err := os.ReadFile(filepath.Join(testDataPath, name)) - s.Require().NoError(err) - return data -} - func (s *HandlerTestSuite) execute( imageURL string, header http.Header, po *options.ProcessingOptions, -) *httptest.ResponseRecorder { +) *http.Response { + imageURL = s.testServer().URL() + imageURL req := httptest.NewRequest("GET", "/", nil) httpheaders.CopyAll(header, req.Header, true) @@ -115,51 +107,42 @@ func (s *HandlerTestSuite) execute( err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww) s.Require().NoError(err) - return rw + return rw.Result() } // TestHandlerBasicRequest checks basic streaming request func (s *HandlerTestSuite) TestHandlerBasicRequest() { - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(httpheaders.ContentType, "image/png") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data) - rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) + res := s.execute("", nil, &options.ProcessingOptions{}) - res := rw.Result() s.Require().Equal(200, res.StatusCode) s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType)) // Verify we get the original image data - actual := rw.Body.Bytes() + actual, err := io.ReadAll(res.Body) + s.Require().NoError(err) s.Require().Equal(data, actual) } // TestHandlerResponseHeadersPassthrough checks that original response headers are // passed through to the client func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() { - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") contentLength := len(data) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(httpheaders.ContentType, "image/png") - w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength)) - w.Header().Set(httpheaders.AcceptRanges, "bytes") - w.Header().Set(httpheaders.Etag, "etag") - w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer().SetHeaders( + httpheaders.ContentType, "image/png", + httpheaders.ContentLength, strconv.Itoa(contentLength), + httpheaders.AcceptRanges, "bytes", + httpheaders.Etag, "etag", + httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT", + ).SetBody(data) - rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) + res := s.execute("", nil, &options.ProcessingOptions{}) - res := rw.Result() s.Require().Equal(200, res.StatusCode) s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType)) s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength)) @@ -172,42 +155,34 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() { // to the server func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() { etag := `"test-etag-123"` - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify that If-None-Match header is passed through - s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch)) - s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding)) - s.Equal("bytes=*", r.Header.Get(httpheaders.Range)) - - w.Header().Set(httpheaders.Etag, etag) - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer(). + SetBody(data). + SetHeaders(httpheaders.Etag, etag). + SetHook(func(r *http.Request, rw http.ResponseWriter) { + // Verify that If-None-Match header is passed through + s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch)) + s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding)) + s.Equal("bytes=*", r.Header.Get(httpheaders.Range)) + }) h := make(http.Header) h.Set(httpheaders.IfNoneMatch, etag) h.Set(httpheaders.AcceptEncoding, "gzip") h.Set(httpheaders.Range, "bytes=*") - rw := s.execute(ts.URL, h, &options.ProcessingOptions{}) + res := s.execute("", h, &options.ProcessingOptions{}) - res := rw.Result() s.Require().Equal(200, res.StatusCode) s.Require().Equal(etag, res.Header.Get(httpheaders.Etag)) } // TestHandlerContentDisposition checks that Content-Disposition header is set correctly func (s *HandlerTestSuite) TestHandlerContentDisposition() { - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(httpheaders.ContentType, "image/png") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data) po := &options.ProcessingOptions{ Filename: "custom_name", @@ -215,10 +190,8 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() { } // Use a URL with a .png extension to help content disposition logic - imageURL := ts.URL + "/test.png" - rw := s.execute(imageURL, nil, po) + res := s.execute("/test.png", nil, po) - res := rw.Result() s.Require().Equal(200, res.StatusCode) s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png") s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment") @@ -229,7 +202,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { type testCase struct { name string cacheControlPassthrough bool - setupOriginHeaders func(http.ResponseWriter) + setupOriginHeaders func() timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now expectedStatusCode int validate func(*testing.T, *http.Response) @@ -250,8 +223,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "Passthrough", cacheControlPassthrough: true, - setupOriginHeaders: func(w http.ResponseWriter) { - w.Header().Set(httpheaders.CacheControl, "max-age=3600, public") + setupOriginHeaders: func() { + s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public") }, timestampOffset: nil, expectedStatusCode: 200, @@ -263,8 +236,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "ExpiresPassthrough", cacheControlPassthrough: true, - setupOriginHeaders: func(w http.ResponseWriter) { - w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat)) + setupOriginHeaders: func() { + s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat)) }, timestampOffset: nil, expectedStatusCode: 200, @@ -278,8 +251,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "PassthroughDisabled", cacheControlPassthrough: false, - setupOriginHeaders: func(w http.ResponseWriter) { - w.Header().Set(httpheaders.CacheControl, "max-age=3600, public") + setupOriginHeaders: func() { + s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public") }, timestampOffset: nil, expectedStatusCode: 200, @@ -291,7 +264,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "WithProcessingOptionsExpires", cacheControlPassthrough: false, - setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers timestampOffset: &oneHour, expectedStatusCode: 200, validate: func(t *testing.T, res *http.Response) { @@ -303,9 +275,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "ProcessingOptionsOverridesOrigin", cacheControlPassthrough: true, - setupOriginHeaders: func(w http.ResponseWriter) { + setupOriginHeaders: func() { // Origin has a longer cache time - w.Header().Set(httpheaders.CacheControl, "max-age=7200, public") + s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public") }, timestampOffset: &thirtyMinutes, expectedStatusCode: 200, @@ -318,10 +290,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "BothHeadersPassthroughEnabled", cacheControlPassthrough: true, - setupOriginHeaders: func(w http.ResponseWriter) { + setupOriginHeaders: func() { // Origin has both Cache-Control and Expires headers - w.Header().Set(httpheaders.CacheControl, "max-age=1800, public") - w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat)) + s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public") + s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat)) }, timestampOffset: nil, expectedStatusCode: 200, @@ -336,10 +308,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "ProcessingOptionsOverridesBothOriginHeaders", cacheControlPassthrough: true, - setupOriginHeaders: func(w http.ResponseWriter) { + setupOriginHeaders: func() { // Origin has both Cache-Control and Expires headers with longer cache times - w.Header().Set(httpheaders.CacheControl, "max-age=7200, public") - w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat)) + s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public") + s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat)) }, timestampOffset: &fortyFiveMinutes, // Shorter than origin headers expectedStatusCode: 200, @@ -352,7 +324,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { { name: "NoOriginHeaders", cacheControlPassthrough: false, - setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers timestampOffset: nil, expectedStatusCode: 200, validate: func(t *testing.T, res *http.Response) { @@ -363,15 +334,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { for _, tc := range testCases { s.Run(tc.name, func() { - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tc.setupOriginHeaders(w) - w.Header().Set(httpheaders.ContentType, "image/png") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + if tc.setupOriginHeaders != nil { + tc.setupOriginHeaders() + } + + s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data) s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough s.rwConf().DefaultTTL = 4242 @@ -383,9 +352,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { po.Expires = &expires } - rw := s.execute(ts.URL, nil, po) - - res := rw.Result() + res := s.execute("", nil, po) s.Require().Equal(tc.expectedStatusCode, res.StatusCode) tc.validate(s.T(), res) }) @@ -405,85 +372,64 @@ func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration { // TestHandlerSecurityHeaders tests the security headers set by the streaming service. func (s *HandlerTestSuite) TestHandlerSecurityHeaders() { - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(httpheaders.ContentType, "image/png") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data) - rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) + res := s.execute("", nil, &options.ProcessingOptions{}) - res := rw.Result() - s.Require().Equal(200, res.StatusCode) + s.Require().Equal(http.StatusOK, res.StatusCode) s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy)) } // TestHandlerErrorResponse tests the error responses from the streaming service. func (s *HandlerTestSuite) TestHandlerErrorResponse() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - w.Write([]byte("Not Found")) - })) - defer ts.Close() + s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found")) - rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) + res := s.execute("", nil, &options.ProcessingOptions{}) - res := rw.Result() - s.Require().Equal(404, res.StatusCode) + s.Require().Equal(http.StatusNotFound, res.StatusCode) } // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() { s.config().CookiePassthrough = true - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify cookies are passed through - cookie, cerr := r.Cookie("test_cookie") - if cerr == nil { - s.Equal("test_value", cookie.Value) - } - - w.Header().Set(httpheaders.ContentType, "image/png") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer(). + SetHeaders(httpheaders.Cookie, "test_cookie=test_value"). + SetHook(func(r *http.Request, rw http.ResponseWriter) { + // Verify cookies are passed through + cookie, cerr := r.Cookie("test_cookie") + if cerr == nil { + s.Equal("test_value", cookie.Value) + } + }).SetBody(data) h := make(http.Header) h.Set(httpheaders.Cookie, "test_cookie=test_value") - rw := s.execute(ts.URL, h, &options.ProcessingOptions{}) + res := s.execute("", h, &options.ProcessingOptions{}) - res := rw.Result() s.Require().Equal(200, res.StatusCode) } // TestHandlerCanonicalHeader tests that the canonical header is set correctly func (s *HandlerTestSuite) TestHandlerCanonicalHeader() { - data := s.readTestFile("test1.png") + data := s.testData.Read("test1.png") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(httpheaders.ContentType, "image/png") - w.WriteHeader(200) - w.Write(data) - })) - defer ts.Close() + s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data) for _, sc := range []bool{true, false} { s.rwConf().SetCanonicalHeader = sc - rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) + res := s.execute("", nil, &options.ProcessingOptions{}) - res := rw.Result() s.Require().Equal(200, res.StatusCode) if sc { - s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL)) + s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL())) } else { s.Require().Empty(res.Header.Get(httpheaders.Link)) } diff --git a/imagedata/image_data_test.go b/imagedata/image_data_test.go index 9534aae3..46200f47 100644 --- a/imagedata/image_data_test.go +++ b/imagedata/image_data_test.go @@ -6,17 +6,13 @@ import ( "context" "encoding/base64" "fmt" - "io" "net" "net/http" - "net/http/httptest" - "os" "strconv" "testing" "github.com/stretchr/testify/suite" - "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/fetcher" "github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/ierrors" @@ -25,88 +21,70 @@ import ( ) type ImageDataTestSuite struct { - suite.Suite + testutil.LazySuite - server *httptest.Server + fetcherCfg testutil.LazyObj[*fetcher.Config] + factory testutil.LazyObj[*Factory] + testServer testutil.LazyTestServer - status int - data []byte - header http.Header - check func(*http.Request) - factory *Factory - - defaultData []byte + data []byte } func (s *ImageDataTestSuite) SetupSuite() { - config.Reset() - config.ClientKeepAliveTimeout = 0 + s.data = testutil.NewTestDataProvider(s.T).Read("test1.jpg") - f, err := os.Open("../testdata/test1.jpg") - s.Require().NoError(err) - defer f.Close() + s.fetcherCfg, _ = testutil.NewLazySuiteObj( + s, + func() (*fetcher.Config, error) { + c := fetcher.NewDefaultConfig() + c.Transport.HTTP.AllowLoopbackSourceAddresses = true + c.Transport.HTTP.ClientKeepAliveTimeout = 0 - data, err := io.ReadAll(f) - s.Require().NoError(err) + return &c, nil + }, + ) - s.defaultData = data + s.factory, _ = testutil.NewLazySuiteObj( + s, + func() (*Factory, error) { + fetcher, err := fetcher.New(s.fetcherCfg()) + if err != nil { + return nil, err + } - s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if s.check != nil { - s.check(r) - } + return NewFactory(fetcher), nil + }, + ) - httpheaders.CopyAll(s.header, rw.Header(), true) + s.testServer, _ = testutil.NewLazySuiteTestServer( + s, + func(srv *testutil.TestServer) error { + // Default headers and body for 200 OK response + srv.SetHeaders( + httpheaders.ContentType, "image/jpeg", + httpheaders.ContentLength, strconv.Itoa(len(s.data)), + ).SetBody(s.data) - data := s.data - if data == nil { - data = s.defaultData - } - - rw.Header().Set("Content-Length", strconv.Itoa(len(data))) - - rw.WriteHeader(s.status) - rw.Write(data) - })) - - c, err := fetcher.LoadConfigFromEnv(nil) - s.Require().NoError(err) - - fetcher, err := fetcher.New(c) - s.Require().NoError(err) - - s.factory = NewFactory(fetcher) + return nil + }, + ) } -func (s *ImageDataTestSuite) TearDownSuite() { - s.server.Close() -} - -func (s *ImageDataTestSuite) SetupTest() { - config.Reset() - config.AllowLoopbackSourceAddresses = true - - s.status = http.StatusOK - s.data = nil - s.check = nil - - s.header = http.Header{} - s.header.Set("Content-Type", "image/jpeg") - +func (s *ImageDataTestSuite) SetupSubTest() { + // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest + s.ResetLazyObjects() } func (s *ImageDataTestSuite) TestDownloadStatusOK() { - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().NoError(err) s.Require().NotNil(imgdata) - s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) + s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader())) s.Require().Equal(imagetype.JPEG, imgdata.Format()) } func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() { - s.status = http.StatusPartialContent - testCases := []struct { name string contentRange string @@ -114,17 +92,17 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() { }{ { name: "Full Content-Range", - contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-1, len(s.defaultData)), + contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-1, len(s.data)), expectErr: false, }, { name: "Partial Content-Range, early end", - contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-2, len(s.defaultData)), + contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-2, len(s.data)), expectErr: true, }, { name: "Partial Content-Range, late start", - contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.defaultData)-1, len(s.defaultData)), + contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.data)-1, len(s.data)), expectErr: true, }, { @@ -139,39 +117,41 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() { }, { name: "Unknown Content-Range range", - contentRange: fmt.Sprintf("bytes */%d", len(s.defaultData)), + contentRange: fmt.Sprintf("bytes */%d", len(s.data)), expectErr: true, }, { name: "Unknown Content-Range size, full range", - contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-1), + contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-1), expectErr: false, }, { name: "Unknown Content-Range size, early end", - contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-2), + contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-2), expectErr: true, }, { name: "Unknown Content-Range size, late start", - contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.defaultData)-1), + contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.data)-1), expectErr: true, }, } for _, tc := range testCases { s.Run(tc.name, func() { - s.header.Set("Content-Range", tc.contentRange) + s.testServer(). + SetHeaders(httpheaders.ContentRange, tc.contentRange). + SetStatusCode(http.StatusPartialContent) - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) if tc.expectErr { s.Require().Error(err) - s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) + s.Require().Equal(http.StatusNotFound, ierrors.Wrap(err, 0).StatusCode()) } else { s.Require().NoError(err) s.Require().NotNil(imgdata) - s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) + s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader())) s.Require().Equal(imagetype.JPEG, imgdata.Format()) } }) @@ -179,11 +159,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() { } func (s *ImageDataTestSuite) TestDownloadStatusNotFound() { - s.status = http.StatusNotFound - s.data = []byte("Not Found") - s.header.Set("Content-Type", "text/plain") + s.testServer(). + SetStatusCode(http.StatusNotFound). + SetBody([]byte("Not Found")). + SetHeaders(httpheaders.ContentType, "text/plain") - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().Error(err) s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) @@ -191,11 +172,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusNotFound() { } func (s *ImageDataTestSuite) TestDownloadStatusForbidden() { - s.status = http.StatusForbidden - s.data = []byte("Forbidden") - s.header.Set("Content-Type", "text/plain") + s.testServer(). + SetStatusCode(http.StatusForbidden). + SetBody([]byte("Forbidden")). + SetHeaders(httpheaders.ContentType, "text/plain") - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().Error(err) s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) @@ -203,11 +185,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusForbidden() { } func (s *ImageDataTestSuite) TestDownloadStatusInternalServerError() { - s.status = http.StatusInternalServerError - s.data = []byte("Internal Server Error") - s.header.Set("Content-Type", "text/plain") + s.testServer(). + SetStatusCode(http.StatusInternalServerError). + SetBody([]byte("Internal Server Error")). + SetHeaders(httpheaders.ContentType, "text/plain") - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().Error(err) s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode()) @@ -221,7 +204,7 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() { serverURL := fmt.Sprintf("http://%s", l.Addr().String()) - imgdata, _, err := s.factory.DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{}) s.Require().Error(err) s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode()) @@ -229,19 +212,19 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() { } func (s *ImageDataTestSuite) TestDownloadInvalidImage() { - s.data = []byte("invalid") + s.testServer().SetBody([]byte("invalid")) - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().Error(err) - s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode()) + s.Require().Equal(http.StatusUnprocessableEntity, ierrors.Wrap(err, 0).StatusCode()) s.Require().Nil(imgdata) } func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() { - config.AllowLoopbackSourceAddresses = false + s.fetcherCfg().Transport.HTTP.AllowLoopbackSourceAddresses = false - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().Error(err) s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) @@ -249,11 +232,10 @@ func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() { } func (s *ImageDataTestSuite) TestDownloadImageFileTooLarge() { - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{ + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{ MaxSrcFileSize: 1, }) - fmt.Println(err) s.Require().Error(err) s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode()) s.Require().Nil(imgdata) @@ -263,39 +245,43 @@ func (s *ImageDataTestSuite) TestDownloadGzip() { buf := new(bytes.Buffer) enc := gzip.NewWriter(buf) - _, err := enc.Write(s.defaultData) + _, err := enc.Write(s.data) s.Require().NoError(err) err = enc.Close() s.Require().NoError(err) - s.data = buf.Bytes() - s.header.Set("Content-Encoding", "gzip") + s.testServer(). + SetBody(buf.Bytes()). + SetHeaders( + httpheaders.ContentEncoding, "gzip", + httpheaders.ContentLength, strconv.Itoa(buf.Len()), // Update Content-Length + ) - imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) + imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{}) s.Require().NoError(err) s.Require().NotNil(imgdata) - s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) + s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader())) s.Require().Equal(imagetype.JPEG, imgdata.Format()) } func (s *ImageDataTestSuite) TestFromFile() { - imgdata, err := s.factory.NewFromPath("../testdata/test1.jpg") + imgdata, err := s.factory().NewFromPath("../testdata/test1.jpg") s.Require().NoError(err) s.Require().NotNil(imgdata) - s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) + s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader())) s.Require().Equal(imagetype.JPEG, imgdata.Format()) } func (s *ImageDataTestSuite) TestFromBase64() { - b64 := base64.StdEncoding.EncodeToString(s.defaultData) + b64 := base64.StdEncoding.EncodeToString(s.data) - imgdata, err := s.factory.NewFromBase64(b64) + imgdata, err := s.factory().NewFromBase64(b64) s.Require().NoError(err) s.Require().NotNil(imgdata) - s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) + s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader())) s.Require().Equal(imagetype.JPEG, imgdata.Format()) } diff --git a/integration/processing_handler_test.go b/integration/processing_handler_test.go index 56d60c5b..8fa49ea3 100644 --- a/integration/processing_handler_test.go +++ b/integration/processing_handler_test.go @@ -27,10 +27,7 @@ type ProcessingHandlerTestSuite struct { func (s *ProcessingHandlerTestSuite) SetupTest() { config.Reset() // We reset config only at the start of each test - - // NOTE: This must be moved to security config - config.AllowLoopbackSourceAddresses = true - // NOTE: end note + s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true } func (s *ProcessingHandlerTestSuite) SetupSubTest() { @@ -142,13 +139,13 @@ func (s *ProcessingHandlerTestSuite) TestSourceNetworkValidation() { // We wrap this in a subtest to reset s.router() s.Run("AllowLoopbackSourceAddressesTrue", func() { - config.AllowLoopbackSourceAddresses = true + s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true res := s.GET(url) s.Require().Equal(http.StatusOK, res.StatusCode) }) s.Run("AllowLoopbackSourceAddressesFalse", func() { - config.AllowLoopbackSourceAddresses = false + s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = false res := s.GET(url) s.Require().Equal(http.StatusNotFound, res.StatusCode) }) @@ -256,7 +253,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() { } func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughExpires() { - config.CacheControlPassthrough = true + s.Config().Server.ResponseWriter.CacheControlPassthrough = true ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set(httpheaders.Expires, time.Now().Add(1239*time.Second).UTC().Format(http.TimeFormat)) @@ -290,7 +287,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughDisabled() { } func (s *ProcessingHandlerTestSuite) TestETagDisabled() { - config.ETagEnabled = false + s.Config().Handlers.Processing.ETagEnabled = false res := s.GET("/unsafe/rs:fill:4:4/plain/local:///test1.png") @@ -299,7 +296,7 @@ func (s *ProcessingHandlerTestSuite) TestETagDisabled() { } func (s *ProcessingHandlerTestSuite) TestETagDataMatch() { - config.ETagEnabled = true + s.Config().Handlers.Processing.ETagEnabled = true etag := `"loremipsumdolor"` @@ -321,7 +318,8 @@ func (s *ProcessingHandlerTestSuite) TestETagDataMatch() { } func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() { - config.LastModifiedEnabled = true + s.Config().Handlers.Processing.LastModifiedEnabled = true + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT") rw.WriteHeader(200) @@ -335,7 +333,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() { } func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() { - config.LastModifiedEnabled = false + s.Config().Handlers.Processing.LastModifiedEnabled = false ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT") rw.WriteHeader(200) @@ -349,7 +347,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() { } func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedDisabled() { - config.LastModifiedEnabled = false + s.Config().Handlers.Processing.LastModifiedEnabled = false data := s.TestData.Read("test1.png") lastModified := "Wed, 21 Oct 2015 07:28:00 GMT" ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -368,7 +366,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedD } func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedEnabled() { - config.LastModifiedEnabled = true + s.Config().Handlers.Processing.LastModifiedEnabled = true lastModified := "Wed, 21 Oct 2015 07:28:00 GMT" ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { modifiedSince := r.Header.Get(httpheaders.IfModifiedSince) @@ -386,7 +384,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedE func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqCompareMoreRecentLastModifiedDisabled() { data := s.TestData.Read("test1.png") - config.LastModifiedEnabled = false + s.Config().Handlers.Processing.LastModifiedEnabled = false ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { modifiedSince := r.Header.Get(httpheaders.IfModifiedSince) s.Empty(modifiedSince) diff --git a/options/processing_options.go b/options/processing_options.go index af4010a5..6406ff8a 100644 --- a/options/processing_options.go +++ b/options/processing_options.go @@ -1082,10 +1082,10 @@ func applyURLOption(po *ProcessingOptions, name string, args []string, usedPrese } func applyURLOptions(po *ProcessingOptions, options urlOptions, allowAll bool, usedPresets ...string) error { - allowAll = allowAll || len(config.AllowedProcessiongOptions) == 0 + allowAll = allowAll || len(config.AllowedProcessingOptions) == 0 for _, opt := range options { - if !allowAll && !slices.Contains(config.AllowedProcessiongOptions, opt.Name) { + if !allowAll && !slices.Contains(config.AllowedProcessingOptions, opt.Name) { return newForbiddenOptionError("processing", opt.Name) } diff --git a/options/processing_options_test.go b/options/processing_options_test.go index 039923e9..6d38f6ec 100644 --- a/options/processing_options_test.go +++ b/options/processing_options_test.go @@ -646,7 +646,7 @@ func (s *ProcessingOptionsTestSuite) TestParseBase64URLOnlyPresets() { } func (s *ProcessingOptionsTestSuite) TestParseAllowedOptions() { - config.AllowedProcessiongOptions = []string{"w", "h", "pr"} + config.AllowedProcessingOptions = []string{"w", "h", "pr"} presets["test1"] = urlOptions{ urlOption{Name: "blur", Args: []string{"0.2"}}, diff --git a/security/errors.go b/security/errors.go index 86b19ddc..3754c26f 100644 --- a/security/errors.go +++ b/security/errors.go @@ -13,7 +13,6 @@ type ( ImageResolutionError string SecurityOptionsError struct{} SourceURLError string - SourceAddressError string ) func newSignatureError(msg string) error { @@ -75,15 +74,3 @@ func newSourceURLError(imageURL string) error { } func (e SourceURLError) Error() string { return string(e) } - -func newSourceAddressError(msg string) error { - return ierrors.Wrap( - SourceAddressError(msg), - 1, - ierrors.WithStatusCode(http.StatusNotFound), - ierrors.WithPublicMessage("Invalid source URL"), - ierrors.WithShouldReport(false), - ) -} - -func (e SourceAddressError) Error() string { return string(e) } diff --git a/security/source.go b/security/source.go index bc09b057..9472fe9f 100644 --- a/security/source.go +++ b/security/source.go @@ -1,9 +1,6 @@ package security import ( - "fmt" - "net" - "github.com/imgproxy/imgproxy/v3/config" ) @@ -20,29 +17,3 @@ func VerifySourceURL(imageURL string) error { return newSourceURLError(imageURL) } - -func VerifySourceNetwork(addr string) error { - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - - ip := net.ParseIP(host) - if ip == nil { - return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr)) - } - - if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) { - return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr)) - } - - if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) { - return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr)) - } - - if !config.AllowPrivateSourceAddresses && ip.IsPrivate() { - return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr)) - } - - return nil -} diff --git a/testutil/lasy_suite.go b/testutil/lasy_suite.go index 9e41441a..18fe6049 100644 --- a/testutil/lasy_suite.go +++ b/testutil/lasy_suite.go @@ -51,7 +51,7 @@ func NewLazySuiteObj[T any]( // Get the [LazySuite] instance lazy := s.Lazy() // Create the [LazyObj] instance - obj, cancel := NewLazyObj(lazy, newFn, dropFn...) + obj, cancel := newLazyObj(lazy, newFn, dropFn...) // Add cleanup function to the resets list lazy.resets = append(lazy.resets, cancel) diff --git a/testutil/lazy_obj.go b/testutil/lazy_obj.go index de6d0cbb..d5721472 100644 --- a/testutil/lazy_obj.go +++ b/testutil/lazy_obj.go @@ -23,10 +23,10 @@ type LazyObjNew[T any] func() (T, error) // If the object was not yet initialized, the callback is not called. type LazyObjDrop[T any] func(T) error -// NewLazyObj creates a new [LazyObj] that initializes the object on the first call. +// newLazyObj creates a new [LazyObj] that initializes the object on the first call. // It returns a function that can be called to get the object and a cancel function // that can be called to reset the object. -func NewLazyObj[T any]( +func newLazyObj[T any]( s LazyObjT, newFn LazyObjNew[T], dropFn ...LazyObjDrop[T], diff --git a/testutil/test_data_provider.go b/testutil/test_data_provider.go index 45ad51e1..ac0bf785 100644 --- a/testutil/test_data_provider.go +++ b/testutil/test_data_provider.go @@ -26,9 +26,6 @@ type TestDataProvider struct { // New creates a new TestDataProvider func NewTestDataProvider(t TestDataProviderT) *TestDataProvider { - // if h, ok := t.(interface{ Helper() }); ok { - // h.Helper() - // } t().Helper() path, err := findProjectRoot() diff --git a/testutil/test_server.go b/testutil/test_server.go new file mode 100644 index 00000000..a74ca5da --- /dev/null +++ b/testutil/test_server.go @@ -0,0 +1,122 @@ +package testutil + +import ( + "context" + "net/http" + "net/http/httptest" + + "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/stretchr/testify/require" +) + +// TestServerHookFunc is a function type for in-request hooks +type TestServerHookFunc func(r *http.Request, rw http.ResponseWriter) + +// Sugar alias +type LazyTestServer = LazyObj[*TestServer] + +// TestServer is a syntax sugar wrapper over httptest.Server +type TestServer struct { + testServer *httptest.Server + status int + data []byte + header http.Header + hook TestServerHookFunc +} + +// NewLazySuiteTestServer creates a lazy TestServer object for use in test suites +func NewLazySuiteTestServer( + l LazySuiteFrom, + init ...func(*TestServer) error, +) (LazyObj[*TestServer], context.CancelFunc) { + return NewLazySuiteObj( + l, + func() (*TestServer, error) { + s := NewTestServer() + + if len(init) > 0 { + for _, fn := range init { + if fn == nil { + continue + } + + err := fn(s) + require.NoError(l.Lazy().T(), err, "Failed to reset test server") + } + } + + return s, nil + }, + func(s *TestServer) error { + s.Close() + return nil + }, + ) +} + +// New creates and starts new http.TestServer +func NewTestServer() *TestServer { + ts := &TestServer{ + status: http.StatusOK, + header: make(http.Header), + data: nil, + hook: nil, + } + + return ts.start() +} + +// SetStatusCode sets the status code that will be returned by the server +func (s *TestServer) SetStatusCode(status int) *TestServer { + s.status = status + return s +} + +// SetBody sets the body that will be returned by the server +func (s *TestServer) SetBody(data []byte) *TestServer { + s.data = data + return s +} + +// WithHeader adds headers that will be returned by the server. +// Odd arguments are treated as keys, even arguments as values. +func (s *TestServer) SetHeaders(kv ...string) *TestServer { + for i := 0; i+1 < len(kv); i += 2 { + key := kv[i] + value := kv[i+1] + s.header.Set(key, value) + } + + return s +} + +// SetHook sets a function that will be called on each request. It is called +// after headsers are set, but before status and body are written. +func (s *TestServer) SetHook(f TestServerHookFunc) *TestServer { + s.hook = f + return s +} + +// Start starts the server +func (s *TestServer) start() *TestServer { + s.testServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpheaders.CopyAll(s.header, w.Header(), true) + if s.hook != nil { + s.hook(r, w) + } + w.WriteHeader(s.status) + w.Write(s.data) + })) + + return s +} + +// Close stops the server +func (s *TestServer) Close() { + s.testServer.Close() +} + +// URL returns the server URL +func (s *TestServer) URL() string { + return s.testServer.URL +}