TestServer, AllowNetworks -> http.Transport

This commit is contained in:
Viktor Sokolov
2025-09-11 10:15:31 +02:00
parent 1f6d007948
commit 246ea28864
19 changed files with 425 additions and 388 deletions

View File

@@ -3,77 +3,56 @@ package auximageprovider
import (
"encoding/base64"
"io"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"github.com/stretchr/testify/suite"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/testutil"
)
type ImageProviderTestSuite struct {
suite.Suite
testutil.LazySuite
server *httptest.Server
testData []byte
testDataB64 string
// Server state
status int
data []byte
header http.Header
testServer testutil.LazyTestServer
idf *imagedata.Factory
}
func (s *ImageProviderTestSuite) SetupSuite() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
s.testData = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
s.testDataB64 = base64.StdEncoding.EncodeToString(s.testData)
// Load test image data
f, err := os.Open("../testdata/test1.jpg")
s.Require().NoError(err)
defer f.Close()
fc := fetcher.NewDefaultConfig()
fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
data, err := io.ReadAll(f)
f, err := fetcher.New(&fc)
s.Require().NoError(err)
s.testData = data
s.testDataB64 = base64.StdEncoding.EncodeToString(data)
s.idf = imagedata.NewFactory(f)
// Create test server
s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
for k, vv := range s.header {
for _, v := range vv {
rw.Header().Add(k, v)
}
}
s.testServer, _ = testutil.NewLazySuiteTestServer(
s,
func(srv *testutil.TestServer) error {
srv.SetHeaders(
httpheaders.ContentType, "image/jpeg",
httpheaders.ContentLength, strconv.Itoa(len(s.testData)),
).SetBody(s.testData)
data := s.data
if data == nil {
data = s.testData
}
rw.Header().Set(httpheaders.ContentLength, strconv.Itoa(len(data)))
rw.WriteHeader(s.status)
rw.Write(data)
}))
return nil
},
)
}
func (s *ImageProviderTestSuite) TearDownSuite() {
s.server.Close()
}
func (s *ImageProviderTestSuite) SetupTest() {
s.status = http.StatusOK
s.data = nil
s.header = http.Header{}
s.header.Set(httpheaders.ContentType, "image/jpeg")
func (s *ImageProviderTestSuite) SetupSubTest() {
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
s.ResetLazyObjects()
}
// Helper function to read data from ImageData
@@ -114,7 +93,7 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
},
{
name: "URL",
config: &StaticConfig{URL: s.server.URL},
config: &StaticConfig{URL: s.testServer().URL()},
validateFunc: func(provider Provider) {
s.Equal(s.testData, s.readImageData(provider))
},
@@ -149,10 +128,12 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
},
{
name: "HeadersPassedThrough",
config: &StaticConfig{URL: s.server.URL},
config: &StaticConfig{URL: s.testServer().URL()},
setupFunc: func() {
s.header.Set("X-Custom-Header", "test-value")
s.header.Set(httpheaders.CacheControl, "max-age=3600")
s.testServer().SetHeaders(
"X-Custom-Header", "test-value",
httpheaders.CacheControl, "max-age=3600",
)
},
validateFunc: func(provider Provider) {
imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{})
@@ -167,19 +148,13 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
},
}
fc := fetcher.NewDefaultConfig()
f, err := fetcher.New(&fc)
s.Require().NoError(err)
idf := imagedata.NewFactory(f)
for _, tt := range tests {
s.T().Run(tt.name, func(t *testing.T) {
if tt.setupFunc != nil {
tt.setupFunc()
}
provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", idf)
provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", s.idf)
if tt.expectError {
s.Require().Error(err)

View File

@@ -58,7 +58,7 @@ var (
PngUnlimited bool
SvgUnlimited bool
MaxResultDimension int
AllowedProcessiongOptions []string
AllowedProcessingOptions []string
AllowSecurityOptions bool
JpegProgressive bool
@@ -267,7 +267,7 @@ func Reset() {
PngUnlimited = false
SvgUnlimited = false
MaxResultDimension = 0
AllowedProcessiongOptions = make([]string, 0)
AllowedProcessingOptions = make([]string, 0)
AllowSecurityOptions = false
JpegProgressive = false
@@ -502,7 +502,7 @@ func Configure() error {
configurators.Bool(&SvgUnlimited, "IMGPROXY_SVG_UNLIMITED")
configurators.Int(&MaxResultDimension, "IMGPROXY_MAX_RESULT_DIMENSION")
configurators.StringSlice(&AllowedProcessiongOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
configurators.StringSlice(&AllowedProcessingOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
configurators.Bool(&AllowSecurityOptions, "IMGPROXY_ALLOW_SECURITY_OPTIONS")

View File

@@ -4,10 +4,11 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"
"github.com/imgproxy/imgproxy/v3/fetcher/transport/generichttp"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/security"
)
const msgSourceImageIsUnreachable = "Source image is unreachable"
@@ -157,13 +158,21 @@ func (e NotModifiedError) Headers() http.Header {
func WrapError(err error) error {
isTimeout := false
var secArrdErr security.SourceAddressError
var secArrdErr generichttp.SourceAddressError
switch {
case errors.Is(err, context.DeadlineExceeded):
isTimeout = true
case errors.Is(err, context.Canceled):
return newImageRequestCanceledError(err)
case err == io.ErrUnexpectedEOF:
return ierrors.Wrap(
newImageRequestError(err),
1,
ierrors.WithPublicMessage("source image is corrupted"),
ierrors.WithShouldReport(false),
ierrors.WithStatusCode(http.StatusUnprocessableEntity),
)
case errors.As(err, &secArrdErr):
return ierrors.Wrap(
err,

View File

@@ -10,15 +10,21 @@ import (
// Config holds the configuration for the generic HTTP transport
type Config struct {
ClientKeepAliveTimeout time.Duration
IgnoreSslVerification bool
ClientKeepAliveTimeout time.Duration
IgnoreSslVerification bool
AllowLoopbackSourceAddresses bool
AllowLinkLocalSourceAddresses bool
AllowPrivateSourceAddresses bool
}
// NewDefaultConfig returns a new default configuration for the generic HTTP transport
func NewDefaultConfig() Config {
return Config{
ClientKeepAliveTimeout: 90 * time.Second,
IgnoreSslVerification: false,
ClientKeepAliveTimeout: 90 * time.Second,
IgnoreSslVerification: false,
AllowLoopbackSourceAddresses: false,
AllowLinkLocalSourceAddresses: false,
AllowPrivateSourceAddresses: true,
}
}
@@ -28,6 +34,9 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
c.ClientKeepAliveTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second
c.IgnoreSslVerification = config.IgnoreSslVerification
c.AllowLinkLocalSourceAddresses = config.AllowLinkLocalSourceAddresses
c.AllowLoopbackSourceAddresses = config.AllowLoopbackSourceAddresses
c.AllowPrivateSourceAddresses = config.AllowPrivateSourceAddresses
return c, nil
}

View File

@@ -0,0 +1,23 @@
package generichttp
import (
"net/http"
"github.com/imgproxy/imgproxy/v3/ierrors"
)
type (
SourceAddressError string
)
func newSourceAddressError(msg string) error {
return ierrors.Wrap(
SourceAddressError(msg),
1,
ierrors.WithStatusCode(http.StatusNotFound),
ierrors.WithPublicMessage("Invalid source URL"),
ierrors.WithShouldReport(false),
)
}
func (e SourceAddressError) Error() string { return string(e) }

View File

@@ -3,12 +3,12 @@ package generichttp
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"syscall"
"time"
"github.com/imgproxy/imgproxy/v3/security"
"golang.org/x/net/http2"
)
@@ -25,7 +25,7 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
if verifyNetworks {
dialer.Control = func(network, address string, c syscall.RawConn) error {
return security.VerifySourceNetwork(address)
return verifySourceNetwork(address, config)
}
}
@@ -66,3 +66,29 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
return transport, nil
}
func verifySourceNetwork(addr string, config *Config) error {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ip := net.ParseIP(host)
if ip == nil {
return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
}
if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
}
if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
}
if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
}
return nil
}

View File

@@ -1,9 +1,8 @@
package security
package generichttp
import (
"testing"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/stretchr/testify/require"
)
@@ -100,24 +99,14 @@ func TestVerifySourceNetwork(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Backup original config
originalLoopback := config.AllowLoopbackSourceAddresses
originalLinkLocal := config.AllowLinkLocalSourceAddresses
originalPrivate := config.AllowPrivateSourceAddresses
// Restore original config after test
defer func() {
config.AllowLoopbackSourceAddresses = originalLoopback
config.AllowLinkLocalSourceAddresses = originalLinkLocal
config.AllowPrivateSourceAddresses = originalPrivate
}()
config := NewDefaultConfig()
// Override config for the test
config.AllowLoopbackSourceAddresses = tc.allowLoopback
config.AllowLinkLocalSourceAddresses = tc.allowLinkLocal
config.AllowPrivateSourceAddresses = tc.allowPrivate
err := VerifySourceNetwork(tc.addr)
err := verifySourceNetwork(tc.addr, &config)
if tc.expectErr {
require.Error(t, err)

View File

@@ -1 +0,0 @@
package processing

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"time"
@@ -22,23 +21,24 @@ import (
"github.com/imgproxy/imgproxy/v3/testutil"
)
const (
testDataPath = "../../testdata"
)
type HandlerTestSuite struct {
testutil.LazySuite
testData *testutil.TestDataProvider
rwConf testutil.LazyObj[*responsewriter.Config]
rwFactory testutil.LazyObj[*responsewriter.Factory]
config testutil.LazyObj[*Config]
handler testutil.LazyObj[*Handler]
testServer testutil.LazyTestServer
}
func (s *HandlerTestSuite) SetupSuite() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
s.testData = testutil.NewTestDataProvider(s.T)
s.rwConf, _ = testutil.NewLazySuiteObj(
s,
@@ -67,6 +67,7 @@ func (s *HandlerTestSuite) SetupSuite() {
s,
func() (*Handler, error) {
fc := fetcher.NewDefaultConfig()
fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
fetcher, err := fetcher.New(&fc)
s.Require().NoError(err)
@@ -75,36 +76,27 @@ func (s *HandlerTestSuite) SetupSuite() {
},
)
s.testServer, _ = testutil.NewLazySuiteTestServer(s)
// Silence logs during tests
logrus.SetOutput(io.Discard)
}
func (s *HandlerTestSuite) TearDownSuite() {
config.Reset()
logrus.SetOutput(os.Stdout)
}
func (s *HandlerTestSuite) SetupTest() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
}
func (s *HandlerTestSuite) SetupSubTest() {
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
s.ResetLazyObjects()
}
func (s *HandlerTestSuite) readTestFile(name string) []byte {
data, err := os.ReadFile(filepath.Join(testDataPath, name))
s.Require().NoError(err)
return data
}
func (s *HandlerTestSuite) execute(
imageURL string,
header http.Header,
po *options.ProcessingOptions,
) *httptest.ResponseRecorder {
) *http.Response {
imageURL = s.testServer().URL() + imageURL
req := httptest.NewRequest("GET", "/", nil)
httpheaders.CopyAll(header, req.Header, true)
@@ -115,51 +107,42 @@ func (s *HandlerTestSuite) execute(
err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
s.Require().NoError(err)
return rw
return rw.Result()
}
// TestHandlerBasicRequest checks basic streaming request
func (s *HandlerTestSuite) TestHandlerBasicRequest() {
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
// Verify we get the original image data
actual := rw.Body.Bytes()
actual, err := io.ReadAll(res.Body)
s.Require().NoError(err)
s.Require().Equal(data, actual)
}
// TestHandlerResponseHeadersPassthrough checks that original response headers are
// passed through to the client
func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
contentLength := len(data)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httpheaders.ContentType, "image/png")
w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
w.Header().Set(httpheaders.AcceptRanges, "bytes")
w.Header().Set(httpheaders.Etag, "etag")
w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().SetHeaders(
httpheaders.ContentType, "image/png",
httpheaders.ContentLength, strconv.Itoa(contentLength),
httpheaders.AcceptRanges, "bytes",
httpheaders.Etag, "etag",
httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT",
).SetBody(data)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
@@ -172,42 +155,34 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
// to the server
func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
etag := `"test-etag-123"`
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify that If-None-Match header is passed through
s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
w.Header().Set(httpheaders.Etag, etag)
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().
SetBody(data).
SetHeaders(httpheaders.Etag, etag).
SetHook(func(r *http.Request, rw http.ResponseWriter) {
// Verify that If-None-Match header is passed through
s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
})
h := make(http.Header)
h.Set(httpheaders.IfNoneMatch, etag)
h.Set(httpheaders.AcceptEncoding, "gzip")
h.Set(httpheaders.Range, "bytes=*")
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
res := s.execute("", h, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
}
// TestHandlerContentDisposition checks that Content-Disposition header is set correctly
func (s *HandlerTestSuite) TestHandlerContentDisposition() {
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
po := &options.ProcessingOptions{
Filename: "custom_name",
@@ -215,10 +190,8 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
}
// Use a URL with a .png extension to help content disposition logic
imageURL := ts.URL + "/test.png"
rw := s.execute(imageURL, nil, po)
res := s.execute("/test.png", nil, po)
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
@@ -229,7 +202,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
type testCase struct {
name string
cacheControlPassthrough bool
setupOriginHeaders func(http.ResponseWriter)
setupOriginHeaders func()
timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
expectedStatusCode int
validate func(*testing.T, *http.Response)
@@ -250,8 +223,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "Passthrough",
cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) {
w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
setupOriginHeaders: func() {
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
},
timestampOffset: nil,
expectedStatusCode: 200,
@@ -263,8 +236,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "ExpiresPassthrough",
cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) {
w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
setupOriginHeaders: func() {
s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
},
timestampOffset: nil,
expectedStatusCode: 200,
@@ -278,8 +251,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "PassthroughDisabled",
cacheControlPassthrough: false,
setupOriginHeaders: func(w http.ResponseWriter) {
w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
setupOriginHeaders: func() {
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
},
timestampOffset: nil,
expectedStatusCode: 200,
@@ -291,7 +264,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "WithProcessingOptionsExpires",
cacheControlPassthrough: false,
setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers
timestampOffset: &oneHour,
expectedStatusCode: 200,
validate: func(t *testing.T, res *http.Response) {
@@ -303,9 +275,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "ProcessingOptionsOverridesOrigin",
cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) {
setupOriginHeaders: func() {
// Origin has a longer cache time
w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
},
timestampOffset: &thirtyMinutes,
expectedStatusCode: 200,
@@ -318,10 +290,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "BothHeadersPassthroughEnabled",
cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) {
setupOriginHeaders: func() {
// Origin has both Cache-Control and Expires headers
w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public")
s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
},
timestampOffset: nil,
expectedStatusCode: 200,
@@ -336,10 +308,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "ProcessingOptionsOverridesBothOriginHeaders",
cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) {
setupOriginHeaders: func() {
// Origin has both Cache-Control and Expires headers with longer cache times
w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
},
timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
expectedStatusCode: 200,
@@ -352,7 +324,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{
name: "NoOriginHeaders",
cacheControlPassthrough: false,
setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers
timestampOffset: nil,
expectedStatusCode: 200,
validate: func(t *testing.T, res *http.Response) {
@@ -363,15 +334,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
for _, tc := range testCases {
s.Run(tc.name, func() {
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tc.setupOriginHeaders(w)
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
if tc.setupOriginHeaders != nil {
tc.setupOriginHeaders()
}
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
s.rwConf().DefaultTTL = 4242
@@ -383,9 +352,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
po.Expires = &expires
}
rw := s.execute(ts.URL, nil, po)
res := rw.Result()
res := s.execute("", nil, po)
s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
tc.validate(s.T(), res)
})
@@ -405,85 +372,64 @@ func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
// TestHandlerSecurityHeaders tests the security headers set by the streaming service.
func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
s.Require().Equal(http.StatusOK, res.StatusCode)
s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
}
// TestHandlerErrorResponse tests the error responses from the streaming service.
func (s *HandlerTestSuite) TestHandlerErrorResponse() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte("Not Found"))
}))
defer ts.Close()
s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found"))
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(404, res.StatusCode)
s.Require().Equal(http.StatusNotFound, res.StatusCode)
}
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
s.config().CookiePassthrough = true
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify cookies are passed through
cookie, cerr := r.Cookie("test_cookie")
if cerr == nil {
s.Equal("test_value", cookie.Value)
}
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().
SetHeaders(httpheaders.Cookie, "test_cookie=test_value").
SetHook(func(r *http.Request, rw http.ResponseWriter) {
// Verify cookies are passed through
cookie, cerr := r.Cookie("test_cookie")
if cerr == nil {
s.Equal("test_value", cookie.Value)
}
}).SetBody(data)
h := make(http.Header)
h.Set(httpheaders.Cookie, "test_cookie=test_value")
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
res := s.execute("", h, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
}
// TestHandlerCanonicalHeader tests that the canonical header is set correctly
func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
data := s.readTestFile("test1.png")
data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
for _, sc := range []bool{true, false} {
s.rwConf().SetCanonicalHeader = sc
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
if sc {
s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL()))
} else {
s.Require().Empty(res.Header.Get(httpheaders.Link))
}

View File

@@ -6,17 +6,13 @@ import (
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"github.com/stretchr/testify/suite"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/ierrors"
@@ -25,88 +21,70 @@ import (
)
type ImageDataTestSuite struct {
suite.Suite
testutil.LazySuite
server *httptest.Server
fetcherCfg testutil.LazyObj[*fetcher.Config]
factory testutil.LazyObj[*Factory]
testServer testutil.LazyTestServer
status int
data []byte
header http.Header
check func(*http.Request)
factory *Factory
defaultData []byte
data []byte
}
func (s *ImageDataTestSuite) SetupSuite() {
config.Reset()
config.ClientKeepAliveTimeout = 0
s.data = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
f, err := os.Open("../testdata/test1.jpg")
s.Require().NoError(err)
defer f.Close()
s.fetcherCfg, _ = testutil.NewLazySuiteObj(
s,
func() (*fetcher.Config, error) {
c := fetcher.NewDefaultConfig()
c.Transport.HTTP.AllowLoopbackSourceAddresses = true
c.Transport.HTTP.ClientKeepAliveTimeout = 0
data, err := io.ReadAll(f)
s.Require().NoError(err)
return &c, nil
},
)
s.defaultData = data
s.factory, _ = testutil.NewLazySuiteObj(
s,
func() (*Factory, error) {
fetcher, err := fetcher.New(s.fetcherCfg())
if err != nil {
return nil, err
}
s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if s.check != nil {
s.check(r)
}
return NewFactory(fetcher), nil
},
)
httpheaders.CopyAll(s.header, rw.Header(), true)
s.testServer, _ = testutil.NewLazySuiteTestServer(
s,
func(srv *testutil.TestServer) error {
// Default headers and body for 200 OK response
srv.SetHeaders(
httpheaders.ContentType, "image/jpeg",
httpheaders.ContentLength, strconv.Itoa(len(s.data)),
).SetBody(s.data)
data := s.data
if data == nil {
data = s.defaultData
}
rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
rw.WriteHeader(s.status)
rw.Write(data)
}))
c, err := fetcher.LoadConfigFromEnv(nil)
s.Require().NoError(err)
fetcher, err := fetcher.New(c)
s.Require().NoError(err)
s.factory = NewFactory(fetcher)
return nil
},
)
}
func (s *ImageDataTestSuite) TearDownSuite() {
s.server.Close()
}
func (s *ImageDataTestSuite) SetupTest() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
s.status = http.StatusOK
s.data = nil
s.check = nil
s.header = http.Header{}
s.header.Set("Content-Type", "image/jpeg")
func (s *ImageDataTestSuite) SetupSubTest() {
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
s.ResetLazyObjects()
}
func (s *ImageDataTestSuite) TestDownloadStatusOK() {
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().NoError(err)
s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format())
}
func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
s.status = http.StatusPartialContent
testCases := []struct {
name string
contentRange string
@@ -114,17 +92,17 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
}{
{
name: "Full Content-Range",
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-1, len(s.defaultData)),
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-1, len(s.data)),
expectErr: false,
},
{
name: "Partial Content-Range, early end",
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-2, len(s.defaultData)),
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-2, len(s.data)),
expectErr: true,
},
{
name: "Partial Content-Range, late start",
contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.defaultData)-1, len(s.defaultData)),
contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.data)-1, len(s.data)),
expectErr: true,
},
{
@@ -139,39 +117,41 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
},
{
name: "Unknown Content-Range range",
contentRange: fmt.Sprintf("bytes */%d", len(s.defaultData)),
contentRange: fmt.Sprintf("bytes */%d", len(s.data)),
expectErr: true,
},
{
name: "Unknown Content-Range size, full range",
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-1),
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-1),
expectErr: false,
},
{
name: "Unknown Content-Range size, early end",
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-2),
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-2),
expectErr: true,
},
{
name: "Unknown Content-Range size, late start",
contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.defaultData)-1),
contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.data)-1),
expectErr: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
s.header.Set("Content-Range", tc.contentRange)
s.testServer().
SetHeaders(httpheaders.ContentRange, tc.contentRange).
SetStatusCode(http.StatusPartialContent)
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
if tc.expectErr {
s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
s.Require().Equal(http.StatusNotFound, ierrors.Wrap(err, 0).StatusCode())
} else {
s.Require().NoError(err)
s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format())
}
})
@@ -179,11 +159,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
}
func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
s.status = http.StatusNotFound
s.data = []byte("Not Found")
s.header.Set("Content-Type", "text/plain")
s.testServer().
SetStatusCode(http.StatusNotFound).
SetBody([]byte("Not Found")).
SetHeaders(httpheaders.ContentType, "text/plain")
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -191,11 +172,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
}
func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
s.status = http.StatusForbidden
s.data = []byte("Forbidden")
s.header.Set("Content-Type", "text/plain")
s.testServer().
SetStatusCode(http.StatusForbidden).
SetBody([]byte("Forbidden")).
SetHeaders(httpheaders.ContentType, "text/plain")
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -203,11 +185,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
}
func (s *ImageDataTestSuite) TestDownloadStatusInternalServerError() {
s.status = http.StatusInternalServerError
s.data = []byte("Internal Server Error")
s.header.Set("Content-Type", "text/plain")
s.testServer().
SetStatusCode(http.StatusInternalServerError).
SetBody([]byte("Internal Server Error")).
SetHeaders(httpheaders.ContentType, "text/plain")
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err)
s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
@@ -221,7 +204,7 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
serverURL := fmt.Sprintf("http://%s", l.Addr().String())
imgdata, _, err := s.factory.DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
s.Require().Error(err)
s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
@@ -229,19 +212,19 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
}
func (s *ImageDataTestSuite) TestDownloadInvalidImage() {
s.data = []byte("invalid")
s.testServer().SetBody([]byte("invalid"))
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err)
s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
s.Require().Equal(http.StatusUnprocessableEntity, ierrors.Wrap(err, 0).StatusCode())
s.Require().Nil(imgdata)
}
func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
config.AllowLoopbackSourceAddresses = false
s.fetcherCfg().Transport.HTTP.AllowLoopbackSourceAddresses = false
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -249,11 +232,10 @@ func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
}
func (s *ImageDataTestSuite) TestDownloadImageFileTooLarge() {
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{
MaxSrcFileSize: 1,
})
fmt.Println(err)
s.Require().Error(err)
s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
s.Require().Nil(imgdata)
@@ -263,39 +245,43 @@ func (s *ImageDataTestSuite) TestDownloadGzip() {
buf := new(bytes.Buffer)
enc := gzip.NewWriter(buf)
_, err := enc.Write(s.defaultData)
_, err := enc.Write(s.data)
s.Require().NoError(err)
err = enc.Close()
s.Require().NoError(err)
s.data = buf.Bytes()
s.header.Set("Content-Encoding", "gzip")
s.testServer().
SetBody(buf.Bytes()).
SetHeaders(
httpheaders.ContentEncoding, "gzip",
httpheaders.ContentLength, strconv.Itoa(buf.Len()), // Update Content-Length
)
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().NoError(err)
s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format())
}
func (s *ImageDataTestSuite) TestFromFile() {
imgdata, err := s.factory.NewFromPath("../testdata/test1.jpg")
imgdata, err := s.factory().NewFromPath("../testdata/test1.jpg")
s.Require().NoError(err)
s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format())
}
func (s *ImageDataTestSuite) TestFromBase64() {
b64 := base64.StdEncoding.EncodeToString(s.defaultData)
b64 := base64.StdEncoding.EncodeToString(s.data)
imgdata, err := s.factory.NewFromBase64(b64)
imgdata, err := s.factory().NewFromBase64(b64)
s.Require().NoError(err)
s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format())
}

View File

@@ -27,10 +27,7 @@ type ProcessingHandlerTestSuite struct {
func (s *ProcessingHandlerTestSuite) SetupTest() {
config.Reset() // We reset config only at the start of each test
// NOTE: This must be moved to security config
config.AllowLoopbackSourceAddresses = true
// NOTE: end note
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
}
func (s *ProcessingHandlerTestSuite) SetupSubTest() {
@@ -142,13 +139,13 @@ func (s *ProcessingHandlerTestSuite) TestSourceNetworkValidation() {
// We wrap this in a subtest to reset s.router()
s.Run("AllowLoopbackSourceAddressesTrue", func() {
config.AllowLoopbackSourceAddresses = true
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
res := s.GET(url)
s.Require().Equal(http.StatusOK, res.StatusCode)
})
s.Run("AllowLoopbackSourceAddressesFalse", func() {
config.AllowLoopbackSourceAddresses = false
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = false
res := s.GET(url)
s.Require().Equal(http.StatusNotFound, res.StatusCode)
})
@@ -256,7 +253,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
}
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughExpires() {
config.CacheControlPassthrough = true
s.Config().Server.ResponseWriter.CacheControlPassthrough = true
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.Expires, time.Now().Add(1239*time.Second).UTC().Format(http.TimeFormat))
@@ -290,7 +287,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughDisabled() {
}
func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
config.ETagEnabled = false
s.Config().Handlers.Processing.ETagEnabled = false
res := s.GET("/unsafe/rs:fill:4:4/plain/local:///test1.png")
@@ -299,7 +296,7 @@ func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
}
func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
config.ETagEnabled = true
s.Config().Handlers.Processing.ETagEnabled = true
etag := `"loremipsumdolor"`
@@ -321,7 +318,8 @@ func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
}
func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
config.LastModifiedEnabled = true
s.Config().Handlers.Processing.LastModifiedEnabled = true
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
rw.WriteHeader(200)
@@ -335,7 +333,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
}
func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
config.LastModifiedEnabled = false
s.Config().Handlers.Processing.LastModifiedEnabled = false
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
rw.WriteHeader(200)
@@ -349,7 +347,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
}
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedDisabled() {
config.LastModifiedEnabled = false
s.Config().Handlers.Processing.LastModifiedEnabled = false
data := s.TestData.Read("test1.png")
lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -368,7 +366,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedD
}
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedEnabled() {
config.LastModifiedEnabled = true
s.Config().Handlers.Processing.LastModifiedEnabled = true
lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
@@ -386,7 +384,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedE
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqCompareMoreRecentLastModifiedDisabled() {
data := s.TestData.Read("test1.png")
config.LastModifiedEnabled = false
s.Config().Handlers.Processing.LastModifiedEnabled = false
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
s.Empty(modifiedSince)

View File

@@ -1082,10 +1082,10 @@ func applyURLOption(po *ProcessingOptions, name string, args []string, usedPrese
}
func applyURLOptions(po *ProcessingOptions, options urlOptions, allowAll bool, usedPresets ...string) error {
allowAll = allowAll || len(config.AllowedProcessiongOptions) == 0
allowAll = allowAll || len(config.AllowedProcessingOptions) == 0
for _, opt := range options {
if !allowAll && !slices.Contains(config.AllowedProcessiongOptions, opt.Name) {
if !allowAll && !slices.Contains(config.AllowedProcessingOptions, opt.Name) {
return newForbiddenOptionError("processing", opt.Name)
}

View File

@@ -646,7 +646,7 @@ func (s *ProcessingOptionsTestSuite) TestParseBase64URLOnlyPresets() {
}
func (s *ProcessingOptionsTestSuite) TestParseAllowedOptions() {
config.AllowedProcessiongOptions = []string{"w", "h", "pr"}
config.AllowedProcessingOptions = []string{"w", "h", "pr"}
presets["test1"] = urlOptions{
urlOption{Name: "blur", Args: []string{"0.2"}},

View File

@@ -13,7 +13,6 @@ type (
ImageResolutionError string
SecurityOptionsError struct{}
SourceURLError string
SourceAddressError string
)
func newSignatureError(msg string) error {
@@ -75,15 +74,3 @@ func newSourceURLError(imageURL string) error {
}
func (e SourceURLError) Error() string { return string(e) }
func newSourceAddressError(msg string) error {
return ierrors.Wrap(
SourceAddressError(msg),
1,
ierrors.WithStatusCode(http.StatusNotFound),
ierrors.WithPublicMessage("Invalid source URL"),
ierrors.WithShouldReport(false),
)
}
func (e SourceAddressError) Error() string { return string(e) }

View File

@@ -1,9 +1,6 @@
package security
import (
"fmt"
"net"
"github.com/imgproxy/imgproxy/v3/config"
)
@@ -20,29 +17,3 @@ func VerifySourceURL(imageURL string) error {
return newSourceURLError(imageURL)
}
func VerifySourceNetwork(addr string) error {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ip := net.ParseIP(host)
if ip == nil {
return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
}
if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
}
if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
}
if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
}
return nil
}

View File

@@ -51,7 +51,7 @@ func NewLazySuiteObj[T any](
// Get the [LazySuite] instance
lazy := s.Lazy()
// Create the [LazyObj] instance
obj, cancel := NewLazyObj(lazy, newFn, dropFn...)
obj, cancel := newLazyObj(lazy, newFn, dropFn...)
// Add cleanup function to the resets list
lazy.resets = append(lazy.resets, cancel)

View File

@@ -23,10 +23,10 @@ type LazyObjNew[T any] func() (T, error)
// If the object was not yet initialized, the callback is not called.
type LazyObjDrop[T any] func(T) error
// NewLazyObj creates a new [LazyObj] that initializes the object on the first call.
// newLazyObj creates a new [LazyObj] that initializes the object on the first call.
// It returns a function that can be called to get the object and a cancel function
// that can be called to reset the object.
func NewLazyObj[T any](
func newLazyObj[T any](
s LazyObjT,
newFn LazyObjNew[T],
dropFn ...LazyObjDrop[T],

View File

@@ -26,9 +26,6 @@ type TestDataProvider struct {
// New creates a new TestDataProvider
func NewTestDataProvider(t TestDataProviderT) *TestDataProvider {
// if h, ok := t.(interface{ Helper() }); ok {
// h.Helper()
// }
t().Helper()
path, err := findProjectRoot()

122
testutil/test_server.go Normal file
View File

@@ -0,0 +1,122 @@
package testutil
import (
"context"
"net/http"
"net/http/httptest"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/stretchr/testify/require"
)
// TestServerHookFunc is a function type for in-request hooks
type TestServerHookFunc func(r *http.Request, rw http.ResponseWriter)
// Sugar alias
type LazyTestServer = LazyObj[*TestServer]
// TestServer is a syntax sugar wrapper over httptest.Server
type TestServer struct {
testServer *httptest.Server
status int
data []byte
header http.Header
hook TestServerHookFunc
}
// NewLazySuiteTestServer creates a lazy TestServer object for use in test suites
func NewLazySuiteTestServer(
l LazySuiteFrom,
init ...func(*TestServer) error,
) (LazyObj[*TestServer], context.CancelFunc) {
return NewLazySuiteObj(
l,
func() (*TestServer, error) {
s := NewTestServer()
if len(init) > 0 {
for _, fn := range init {
if fn == nil {
continue
}
err := fn(s)
require.NoError(l.Lazy().T(), err, "Failed to reset test server")
}
}
return s, nil
},
func(s *TestServer) error {
s.Close()
return nil
},
)
}
// New creates and starts new http.TestServer
func NewTestServer() *TestServer {
ts := &TestServer{
status: http.StatusOK,
header: make(http.Header),
data: nil,
hook: nil,
}
return ts.start()
}
// SetStatusCode sets the status code that will be returned by the server
func (s *TestServer) SetStatusCode(status int) *TestServer {
s.status = status
return s
}
// SetBody sets the body that will be returned by the server
func (s *TestServer) SetBody(data []byte) *TestServer {
s.data = data
return s
}
// WithHeader adds headers that will be returned by the server.
// Odd arguments are treated as keys, even arguments as values.
func (s *TestServer) SetHeaders(kv ...string) *TestServer {
for i := 0; i+1 < len(kv); i += 2 {
key := kv[i]
value := kv[i+1]
s.header.Set(key, value)
}
return s
}
// SetHook sets a function that will be called on each request. It is called
// after headsers are set, but before status and body are written.
func (s *TestServer) SetHook(f TestServerHookFunc) *TestServer {
s.hook = f
return s
}
// Start starts the server
func (s *TestServer) start() *TestServer {
s.testServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpheaders.CopyAll(s.header, w.Header(), true)
if s.hook != nil {
s.hook(r, w)
}
w.WriteHeader(s.status)
w.Write(s.data)
}))
return s
}
// Close stops the server
func (s *TestServer) Close() {
s.testServer.Close()
}
// URL returns the server URL
func (s *TestServer) URL() string {
return s.testServer.URL
}