TestServer, AllowNetworks -> http.Transport

This commit is contained in:
Viktor Sokolov
2025-09-11 10:15:31 +02:00
parent 1f6d007948
commit 246ea28864
19 changed files with 425 additions and 388 deletions

View File

@@ -3,77 +3,56 @@ package auximageprovider
import ( import (
"encoding/base64" "encoding/base64"
"io" "io"
"net/http"
"net/http/httptest"
"os"
"strconv" "strconv"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/fetcher" "github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/testutil"
) )
type ImageProviderTestSuite struct { type ImageProviderTestSuite struct {
suite.Suite testutil.LazySuite
server *httptest.Server
testData []byte testData []byte
testDataB64 string testDataB64 string
// Server state testServer testutil.LazyTestServer
status int idf *imagedata.Factory
data []byte
header http.Header
} }
func (s *ImageProviderTestSuite) SetupSuite() { func (s *ImageProviderTestSuite) SetupSuite() {
config.Reset() s.testData = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
config.AllowLoopbackSourceAddresses = true s.testDataB64 = base64.StdEncoding.EncodeToString(s.testData)
// Load test image data fc := fetcher.NewDefaultConfig()
f, err := os.Open("../testdata/test1.jpg") fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
s.Require().NoError(err)
defer f.Close()
data, err := io.ReadAll(f) f, err := fetcher.New(&fc)
s.Require().NoError(err) s.Require().NoError(err)
s.testData = data s.idf = imagedata.NewFactory(f)
s.testDataB64 = base64.StdEncoding.EncodeToString(data)
// Create test server s.testServer, _ = testutil.NewLazySuiteTestServer(
s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { s,
for k, vv := range s.header { func(srv *testutil.TestServer) error {
for _, v := range vv { srv.SetHeaders(
rw.Header().Add(k, v) httpheaders.ContentType, "image/jpeg",
} httpheaders.ContentLength, strconv.Itoa(len(s.testData)),
} ).SetBody(s.testData)
data := s.data return nil
if data == nil { },
data = s.testData )
}
rw.Header().Set(httpheaders.ContentLength, strconv.Itoa(len(data)))
rw.WriteHeader(s.status)
rw.Write(data)
}))
} }
func (s *ImageProviderTestSuite) TearDownSuite() { func (s *ImageProviderTestSuite) SetupSubTest() {
s.server.Close() // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
} s.ResetLazyObjects()
func (s *ImageProviderTestSuite) SetupTest() {
s.status = http.StatusOK
s.data = nil
s.header = http.Header{}
s.header.Set(httpheaders.ContentType, "image/jpeg")
} }
// Helper function to read data from ImageData // Helper function to read data from ImageData
@@ -114,7 +93,7 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
}, },
{ {
name: "URL", name: "URL",
config: &StaticConfig{URL: s.server.URL}, config: &StaticConfig{URL: s.testServer().URL()},
validateFunc: func(provider Provider) { validateFunc: func(provider Provider) {
s.Equal(s.testData, s.readImageData(provider)) s.Equal(s.testData, s.readImageData(provider))
}, },
@@ -149,10 +128,12 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
}, },
{ {
name: "HeadersPassedThrough", name: "HeadersPassedThrough",
config: &StaticConfig{URL: s.server.URL}, config: &StaticConfig{URL: s.testServer().URL()},
setupFunc: func() { setupFunc: func() {
s.header.Set("X-Custom-Header", "test-value") s.testServer().SetHeaders(
s.header.Set(httpheaders.CacheControl, "max-age=3600") "X-Custom-Header", "test-value",
httpheaders.CacheControl, "max-age=3600",
)
}, },
validateFunc: func(provider Provider) { validateFunc: func(provider Provider) {
imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{}) imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{})
@@ -167,19 +148,13 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
}, },
} }
fc := fetcher.NewDefaultConfig()
f, err := fetcher.New(&fc)
s.Require().NoError(err)
idf := imagedata.NewFactory(f)
for _, tt := range tests { for _, tt := range tests {
s.T().Run(tt.name, func(t *testing.T) { s.T().Run(tt.name, func(t *testing.T) {
if tt.setupFunc != nil { if tt.setupFunc != nil {
tt.setupFunc() tt.setupFunc()
} }
provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", idf) provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", s.idf)
if tt.expectError { if tt.expectError {
s.Require().Error(err) s.Require().Error(err)

View File

@@ -58,7 +58,7 @@ var (
PngUnlimited bool PngUnlimited bool
SvgUnlimited bool SvgUnlimited bool
MaxResultDimension int MaxResultDimension int
AllowedProcessiongOptions []string AllowedProcessingOptions []string
AllowSecurityOptions bool AllowSecurityOptions bool
JpegProgressive bool JpegProgressive bool
@@ -267,7 +267,7 @@ func Reset() {
PngUnlimited = false PngUnlimited = false
SvgUnlimited = false SvgUnlimited = false
MaxResultDimension = 0 MaxResultDimension = 0
AllowedProcessiongOptions = make([]string, 0) AllowedProcessingOptions = make([]string, 0)
AllowSecurityOptions = false AllowSecurityOptions = false
JpegProgressive = false JpegProgressive = false
@@ -502,7 +502,7 @@ func Configure() error {
configurators.Bool(&SvgUnlimited, "IMGPROXY_SVG_UNLIMITED") configurators.Bool(&SvgUnlimited, "IMGPROXY_SVG_UNLIMITED")
configurators.Int(&MaxResultDimension, "IMGPROXY_MAX_RESULT_DIMENSION") configurators.Int(&MaxResultDimension, "IMGPROXY_MAX_RESULT_DIMENSION")
configurators.StringSlice(&AllowedProcessiongOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS") configurators.StringSlice(&AllowedProcessingOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
configurators.Bool(&AllowSecurityOptions, "IMGPROXY_ALLOW_SECURITY_OPTIONS") configurators.Bool(&AllowSecurityOptions, "IMGPROXY_ALLOW_SECURITY_OPTIONS")

View File

@@ -4,10 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"github.com/imgproxy/imgproxy/v3/fetcher/transport/generichttp"
"github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/security"
) )
const msgSourceImageIsUnreachable = "Source image is unreachable" const msgSourceImageIsUnreachable = "Source image is unreachable"
@@ -157,13 +158,21 @@ func (e NotModifiedError) Headers() http.Header {
func WrapError(err error) error { func WrapError(err error) error {
isTimeout := false isTimeout := false
var secArrdErr security.SourceAddressError var secArrdErr generichttp.SourceAddressError
switch { switch {
case errors.Is(err, context.DeadlineExceeded): case errors.Is(err, context.DeadlineExceeded):
isTimeout = true isTimeout = true
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
return newImageRequestCanceledError(err) return newImageRequestCanceledError(err)
case err == io.ErrUnexpectedEOF:
return ierrors.Wrap(
newImageRequestError(err),
1,
ierrors.WithPublicMessage("source image is corrupted"),
ierrors.WithShouldReport(false),
ierrors.WithStatusCode(http.StatusUnprocessableEntity),
)
case errors.As(err, &secArrdErr): case errors.As(err, &secArrdErr):
return ierrors.Wrap( return ierrors.Wrap(
err, err,

View File

@@ -10,15 +10,21 @@ import (
// Config holds the configuration for the generic HTTP transport // Config holds the configuration for the generic HTTP transport
type Config struct { type Config struct {
ClientKeepAliveTimeout time.Duration ClientKeepAliveTimeout time.Duration
IgnoreSslVerification bool IgnoreSslVerification bool
AllowLoopbackSourceAddresses bool
AllowLinkLocalSourceAddresses bool
AllowPrivateSourceAddresses bool
} }
// NewDefaultConfig returns a new default configuration for the generic HTTP transport // NewDefaultConfig returns a new default configuration for the generic HTTP transport
func NewDefaultConfig() Config { func NewDefaultConfig() Config {
return Config{ return Config{
ClientKeepAliveTimeout: 90 * time.Second, ClientKeepAliveTimeout: 90 * time.Second,
IgnoreSslVerification: false, IgnoreSslVerification: false,
AllowLoopbackSourceAddresses: false,
AllowLinkLocalSourceAddresses: false,
AllowPrivateSourceAddresses: true,
} }
} }
@@ -28,6 +34,9 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
c.ClientKeepAliveTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second c.ClientKeepAliveTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second
c.IgnoreSslVerification = config.IgnoreSslVerification c.IgnoreSslVerification = config.IgnoreSslVerification
c.AllowLinkLocalSourceAddresses = config.AllowLinkLocalSourceAddresses
c.AllowLoopbackSourceAddresses = config.AllowLoopbackSourceAddresses
c.AllowPrivateSourceAddresses = config.AllowPrivateSourceAddresses
return c, nil return c, nil
} }

View File

@@ -0,0 +1,23 @@
package generichttp
import (
"net/http"
"github.com/imgproxy/imgproxy/v3/ierrors"
)
type (
SourceAddressError string
)
func newSourceAddressError(msg string) error {
return ierrors.Wrap(
SourceAddressError(msg),
1,
ierrors.WithStatusCode(http.StatusNotFound),
ierrors.WithPublicMessage("Invalid source URL"),
ierrors.WithShouldReport(false),
)
}
func (e SourceAddressError) Error() string { return string(e) }

View File

@@ -3,12 +3,12 @@ package generichttp
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"net/http" "net/http"
"syscall" "syscall"
"time" "time"
"github.com/imgproxy/imgproxy/v3/security"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
@@ -25,7 +25,7 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
if verifyNetworks { if verifyNetworks {
dialer.Control = func(network, address string, c syscall.RawConn) error { dialer.Control = func(network, address string, c syscall.RawConn) error {
return security.VerifySourceNetwork(address) return verifySourceNetwork(address, config)
} }
} }
@@ -66,3 +66,29 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
return transport, nil return transport, nil
} }
func verifySourceNetwork(addr string, config *Config) error {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ip := net.ParseIP(host)
if ip == nil {
return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
}
if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
}
if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
}
if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
}
return nil
}

View File

@@ -1,9 +1,8 @@
package security package generichttp
import ( import (
"testing" "testing"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -100,24 +99,14 @@ func TestVerifySourceNetwork(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Backup original config config := NewDefaultConfig()
originalLoopback := config.AllowLoopbackSourceAddresses
originalLinkLocal := config.AllowLinkLocalSourceAddresses
originalPrivate := config.AllowPrivateSourceAddresses
// Restore original config after test
defer func() {
config.AllowLoopbackSourceAddresses = originalLoopback
config.AllowLinkLocalSourceAddresses = originalLinkLocal
config.AllowPrivateSourceAddresses = originalPrivate
}()
// Override config for the test // Override config for the test
config.AllowLoopbackSourceAddresses = tc.allowLoopback config.AllowLoopbackSourceAddresses = tc.allowLoopback
config.AllowLinkLocalSourceAddresses = tc.allowLinkLocal config.AllowLinkLocalSourceAddresses = tc.allowLinkLocal
config.AllowPrivateSourceAddresses = tc.allowPrivate config.AllowPrivateSourceAddresses = tc.allowPrivate
err := VerifySourceNetwork(tc.addr) err := verifySourceNetwork(tc.addr, &config)
if tc.expectErr { if tc.expectErr {
require.Error(t, err) require.Error(t, err)

View File

@@ -1 +0,0 @@
package processing

View File

@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@@ -22,23 +21,24 @@ import (
"github.com/imgproxy/imgproxy/v3/testutil" "github.com/imgproxy/imgproxy/v3/testutil"
) )
const (
testDataPath = "../../testdata"
)
type HandlerTestSuite struct { type HandlerTestSuite struct {
testutil.LazySuite testutil.LazySuite
testData *testutil.TestDataProvider
rwConf testutil.LazyObj[*responsewriter.Config] rwConf testutil.LazyObj[*responsewriter.Config]
rwFactory testutil.LazyObj[*responsewriter.Factory] rwFactory testutil.LazyObj[*responsewriter.Factory]
config testutil.LazyObj[*Config] config testutil.LazyObj[*Config]
handler testutil.LazyObj[*Handler] handler testutil.LazyObj[*Handler]
testServer testutil.LazyTestServer
} }
func (s *HandlerTestSuite) SetupSuite() { func (s *HandlerTestSuite) SetupSuite() {
config.Reset() config.Reset()
config.AllowLoopbackSourceAddresses = true
s.testData = testutil.NewTestDataProvider(s.T)
s.rwConf, _ = testutil.NewLazySuiteObj( s.rwConf, _ = testutil.NewLazySuiteObj(
s, s,
@@ -67,6 +67,7 @@ func (s *HandlerTestSuite) SetupSuite() {
s, s,
func() (*Handler, error) { func() (*Handler, error) {
fc := fetcher.NewDefaultConfig() fc := fetcher.NewDefaultConfig()
fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
fetcher, err := fetcher.New(&fc) fetcher, err := fetcher.New(&fc)
s.Require().NoError(err) s.Require().NoError(err)
@@ -75,36 +76,27 @@ func (s *HandlerTestSuite) SetupSuite() {
}, },
) )
s.testServer, _ = testutil.NewLazySuiteTestServer(s)
// Silence logs during tests // Silence logs during tests
logrus.SetOutput(io.Discard) logrus.SetOutput(io.Discard)
} }
func (s *HandlerTestSuite) TearDownSuite() { func (s *HandlerTestSuite) TearDownSuite() {
config.Reset()
logrus.SetOutput(os.Stdout) logrus.SetOutput(os.Stdout)
} }
func (s *HandlerTestSuite) SetupTest() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
}
func (s *HandlerTestSuite) SetupSubTest() { func (s *HandlerTestSuite) SetupSubTest() {
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
s.ResetLazyObjects() s.ResetLazyObjects()
} }
func (s *HandlerTestSuite) readTestFile(name string) []byte {
data, err := os.ReadFile(filepath.Join(testDataPath, name))
s.Require().NoError(err)
return data
}
func (s *HandlerTestSuite) execute( func (s *HandlerTestSuite) execute(
imageURL string, imageURL string,
header http.Header, header http.Header,
po *options.ProcessingOptions, po *options.ProcessingOptions,
) *httptest.ResponseRecorder { ) *http.Response {
imageURL = s.testServer().URL() + imageURL
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
httpheaders.CopyAll(header, req.Header, true) httpheaders.CopyAll(header, req.Header, true)
@@ -115,51 +107,42 @@ func (s *HandlerTestSuite) execute(
err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww) err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
s.Require().NoError(err) s.Require().NoError(err)
return rw return rw.Result()
} }
// TestHandlerBasicRequest checks basic streaming request // TestHandlerBasicRequest checks basic streaming request
func (s *HandlerTestSuite) TestHandlerBasicRequest() { func (s *HandlerTestSuite) TestHandlerBasicRequest() {
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode) s.Require().Equal(200, res.StatusCode)
s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType)) s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
// Verify we get the original image data // Verify we get the original image data
actual := rw.Body.Bytes() actual, err := io.ReadAll(res.Body)
s.Require().NoError(err)
s.Require().Equal(data, actual) s.Require().Equal(data, actual)
} }
// TestHandlerResponseHeadersPassthrough checks that original response headers are // TestHandlerResponseHeadersPassthrough checks that original response headers are
// passed through to the client // passed through to the client
func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() { func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
contentLength := len(data) contentLength := len(data)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().SetHeaders(
w.Header().Set(httpheaders.ContentType, "image/png") httpheaders.ContentType, "image/png",
w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength)) httpheaders.ContentLength, strconv.Itoa(contentLength),
w.Header().Set(httpheaders.AcceptRanges, "bytes") httpheaders.AcceptRanges, "bytes",
w.Header().Set(httpheaders.Etag, "etag") httpheaders.Etag, "etag",
w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT") httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT",
w.WriteHeader(200) ).SetBody(data)
w.Write(data)
}))
defer ts.Close()
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode) s.Require().Equal(200, res.StatusCode)
s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType)) s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength)) s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
@@ -172,42 +155,34 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
// to the server // to the server
func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() { func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
etag := `"test-etag-123"` etag := `"test-etag-123"`
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().
// Verify that If-None-Match header is passed through SetBody(data).
s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch)) SetHeaders(httpheaders.Etag, etag).
s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding)) SetHook(func(r *http.Request, rw http.ResponseWriter) {
s.Equal("bytes=*", r.Header.Get(httpheaders.Range)) // Verify that If-None-Match header is passed through
s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
w.Header().Set(httpheaders.Etag, etag) s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
w.WriteHeader(200) s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
w.Write(data) })
}))
defer ts.Close()
h := make(http.Header) h := make(http.Header)
h.Set(httpheaders.IfNoneMatch, etag) h.Set(httpheaders.IfNoneMatch, etag)
h.Set(httpheaders.AcceptEncoding, "gzip") h.Set(httpheaders.AcceptEncoding, "gzip")
h.Set(httpheaders.Range, "bytes=*") h.Set(httpheaders.Range, "bytes=*")
rw := s.execute(ts.URL, h, &options.ProcessingOptions{}) res := s.execute("", h, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode) s.Require().Equal(200, res.StatusCode)
s.Require().Equal(etag, res.Header.Get(httpheaders.Etag)) s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
} }
// TestHandlerContentDisposition checks that Content-Disposition header is set correctly // TestHandlerContentDisposition checks that Content-Disposition header is set correctly
func (s *HandlerTestSuite) TestHandlerContentDisposition() { func (s *HandlerTestSuite) TestHandlerContentDisposition() {
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
po := &options.ProcessingOptions{ po := &options.ProcessingOptions{
Filename: "custom_name", Filename: "custom_name",
@@ -215,10 +190,8 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
} }
// Use a URL with a .png extension to help content disposition logic // Use a URL with a .png extension to help content disposition logic
imageURL := ts.URL + "/test.png" res := s.execute("/test.png", nil, po)
rw := s.execute(imageURL, nil, po)
res := rw.Result()
s.Require().Equal(200, res.StatusCode) s.Require().Equal(200, res.StatusCode)
s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png") s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment") s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
@@ -229,7 +202,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
type testCase struct { type testCase struct {
name string name string
cacheControlPassthrough bool cacheControlPassthrough bool
setupOriginHeaders func(http.ResponseWriter) setupOriginHeaders func()
timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
expectedStatusCode int expectedStatusCode int
validate func(*testing.T, *http.Response) validate func(*testing.T, *http.Response)
@@ -250,8 +223,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "Passthrough", name: "Passthrough",
cacheControlPassthrough: true, cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) { setupOriginHeaders: func() {
w.Header().Set(httpheaders.CacheControl, "max-age=3600, public") s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
}, },
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
@@ -263,8 +236,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "ExpiresPassthrough", name: "ExpiresPassthrough",
cacheControlPassthrough: true, cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) { setupOriginHeaders: func() {
w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat)) s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
}, },
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
@@ -278,8 +251,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "PassthroughDisabled", name: "PassthroughDisabled",
cacheControlPassthrough: false, cacheControlPassthrough: false,
setupOriginHeaders: func(w http.ResponseWriter) { setupOriginHeaders: func() {
w.Header().Set(httpheaders.CacheControl, "max-age=3600, public") s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
}, },
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
@@ -291,7 +264,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "WithProcessingOptionsExpires", name: "WithProcessingOptionsExpires",
cacheControlPassthrough: false, cacheControlPassthrough: false,
setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers
timestampOffset: &oneHour, timestampOffset: &oneHour,
expectedStatusCode: 200, expectedStatusCode: 200,
validate: func(t *testing.T, res *http.Response) { validate: func(t *testing.T, res *http.Response) {
@@ -303,9 +275,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "ProcessingOptionsOverridesOrigin", name: "ProcessingOptionsOverridesOrigin",
cacheControlPassthrough: true, cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) { setupOriginHeaders: func() {
// Origin has a longer cache time // Origin has a longer cache time
w.Header().Set(httpheaders.CacheControl, "max-age=7200, public") s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
}, },
timestampOffset: &thirtyMinutes, timestampOffset: &thirtyMinutes,
expectedStatusCode: 200, expectedStatusCode: 200,
@@ -318,10 +290,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "BothHeadersPassthroughEnabled", name: "BothHeadersPassthroughEnabled",
cacheControlPassthrough: true, cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) { setupOriginHeaders: func() {
// Origin has both Cache-Control and Expires headers // Origin has both Cache-Control and Expires headers
w.Header().Set(httpheaders.CacheControl, "max-age=1800, public") s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public")
w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat)) s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
}, },
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
@@ -336,10 +308,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "ProcessingOptionsOverridesBothOriginHeaders", name: "ProcessingOptionsOverridesBothOriginHeaders",
cacheControlPassthrough: true, cacheControlPassthrough: true,
setupOriginHeaders: func(w http.ResponseWriter) { setupOriginHeaders: func() {
// Origin has both Cache-Control and Expires headers with longer cache times // Origin has both Cache-Control and Expires headers with longer cache times
w.Header().Set(httpheaders.CacheControl, "max-age=7200, public") s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat)) s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
}, },
timestampOffset: &fortyFiveMinutes, // Shorter than origin headers timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
expectedStatusCode: 200, expectedStatusCode: 200,
@@ -352,7 +324,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
{ {
name: "NoOriginHeaders", name: "NoOriginHeaders",
cacheControlPassthrough: false, cacheControlPassthrough: false,
setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
validate: func(t *testing.T, res *http.Response) { validate: func(t *testing.T, res *http.Response) {
@@ -363,15 +334,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if tc.setupOriginHeaders != nil {
tc.setupOriginHeaders(w) tc.setupOriginHeaders()
w.Header().Set(httpheaders.ContentType, "image/png") }
w.WriteHeader(200)
w.Write(data) s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
}))
defer ts.Close()
s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
s.rwConf().DefaultTTL = 4242 s.rwConf().DefaultTTL = 4242
@@ -383,9 +352,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
po.Expires = &expires po.Expires = &expires
} }
rw := s.execute(ts.URL, nil, po) res := s.execute("", nil, po)
res := rw.Result()
s.Require().Equal(tc.expectedStatusCode, res.StatusCode) s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
tc.validate(s.T(), res) tc.validate(s.T(), res)
}) })
@@ -405,85 +372,64 @@ func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
// TestHandlerSecurityHeaders tests the security headers set by the streaming service. // TestHandlerSecurityHeaders tests the security headers set by the streaming service.
func (s *HandlerTestSuite) TestHandlerSecurityHeaders() { func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result() s.Require().Equal(http.StatusOK, res.StatusCode)
s.Require().Equal(200, res.StatusCode)
s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy)) s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
} }
// TestHandlerErrorResponse tests the error responses from the streaming service. // TestHandlerErrorResponse tests the error responses from the streaming service.
func (s *HandlerTestSuite) TestHandlerErrorResponse() { func (s *HandlerTestSuite) TestHandlerErrorResponse() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found"))
w.WriteHeader(404)
w.Write([]byte("Not Found"))
}))
defer ts.Close()
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result() s.Require().Equal(http.StatusNotFound, res.StatusCode)
s.Require().Equal(404, res.StatusCode)
} }
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() { func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
s.config().CookiePassthrough = true s.config().CookiePassthrough = true
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().
// Verify cookies are passed through SetHeaders(httpheaders.Cookie, "test_cookie=test_value").
cookie, cerr := r.Cookie("test_cookie") SetHook(func(r *http.Request, rw http.ResponseWriter) {
if cerr == nil { // Verify cookies are passed through
s.Equal("test_value", cookie.Value) cookie, cerr := r.Cookie("test_cookie")
} if cerr == nil {
s.Equal("test_value", cookie.Value)
w.Header().Set(httpheaders.ContentType, "image/png") }
w.WriteHeader(200) }).SetBody(data)
w.Write(data)
}))
defer ts.Close()
h := make(http.Header) h := make(http.Header)
h.Set(httpheaders.Cookie, "test_cookie=test_value") h.Set(httpheaders.Cookie, "test_cookie=test_value")
rw := s.execute(ts.URL, h, &options.ProcessingOptions{}) res := s.execute("", h, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode) s.Require().Equal(200, res.StatusCode)
} }
// TestHandlerCanonicalHeader tests that the canonical header is set correctly // TestHandlerCanonicalHeader tests that the canonical header is set correctly
func (s *HandlerTestSuite) TestHandlerCanonicalHeader() { func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
data := s.readTestFile("test1.png") data := s.testData.Read("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
w.Header().Set(httpheaders.ContentType, "image/png")
w.WriteHeader(200)
w.Write(data)
}))
defer ts.Close()
for _, sc := range []bool{true, false} { for _, sc := range []bool{true, false} {
s.rwConf().SetCanonicalHeader = sc s.rwConf().SetCanonicalHeader = sc
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := s.execute("", nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode) s.Require().Equal(200, res.StatusCode)
if sc { if sc {
s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL)) s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL()))
} else { } else {
s.Require().Empty(res.Header.Get(httpheaders.Link)) s.Require().Empty(res.Header.Get(httpheaders.Link))
} }

View File

@@ -6,17 +6,13 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"os"
"strconv" "strconv"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/fetcher" "github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/ierrors"
@@ -25,88 +21,70 @@ import (
) )
type ImageDataTestSuite struct { type ImageDataTestSuite struct {
suite.Suite testutil.LazySuite
server *httptest.Server fetcherCfg testutil.LazyObj[*fetcher.Config]
factory testutil.LazyObj[*Factory]
testServer testutil.LazyTestServer
status int data []byte
data []byte
header http.Header
check func(*http.Request)
factory *Factory
defaultData []byte
} }
func (s *ImageDataTestSuite) SetupSuite() { func (s *ImageDataTestSuite) SetupSuite() {
config.Reset() s.data = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
config.ClientKeepAliveTimeout = 0
f, err := os.Open("../testdata/test1.jpg") s.fetcherCfg, _ = testutil.NewLazySuiteObj(
s.Require().NoError(err) s,
defer f.Close() func() (*fetcher.Config, error) {
c := fetcher.NewDefaultConfig()
c.Transport.HTTP.AllowLoopbackSourceAddresses = true
c.Transport.HTTP.ClientKeepAliveTimeout = 0
data, err := io.ReadAll(f) return &c, nil
s.Require().NoError(err) },
)
s.defaultData = data s.factory, _ = testutil.NewLazySuiteObj(
s,
func() (*Factory, error) {
fetcher, err := fetcher.New(s.fetcherCfg())
if err != nil {
return nil, err
}
s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return NewFactory(fetcher), nil
if s.check != nil { },
s.check(r) )
}
httpheaders.CopyAll(s.header, rw.Header(), true) s.testServer, _ = testutil.NewLazySuiteTestServer(
s,
func(srv *testutil.TestServer) error {
// Default headers and body for 200 OK response
srv.SetHeaders(
httpheaders.ContentType, "image/jpeg",
httpheaders.ContentLength, strconv.Itoa(len(s.data)),
).SetBody(s.data)
data := s.data return nil
if data == nil { },
data = s.defaultData )
}
rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
rw.WriteHeader(s.status)
rw.Write(data)
}))
c, err := fetcher.LoadConfigFromEnv(nil)
s.Require().NoError(err)
fetcher, err := fetcher.New(c)
s.Require().NoError(err)
s.factory = NewFactory(fetcher)
} }
func (s *ImageDataTestSuite) TearDownSuite() { func (s *ImageDataTestSuite) SetupSubTest() {
s.server.Close() // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
} s.ResetLazyObjects()
func (s *ImageDataTestSuite) SetupTest() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
s.status = http.StatusOK
s.data = nil
s.check = nil
s.header = http.Header{}
s.header.Set("Content-Type", "image/jpeg")
} }
func (s *ImageDataTestSuite) TestDownloadStatusOK() { func (s *ImageDataTestSuite) TestDownloadStatusOK() {
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(imgdata) s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format()) s.Require().Equal(imagetype.JPEG, imgdata.Format())
} }
func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() { func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
s.status = http.StatusPartialContent
testCases := []struct { testCases := []struct {
name string name string
contentRange string contentRange string
@@ -114,17 +92,17 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
}{ }{
{ {
name: "Full Content-Range", name: "Full Content-Range",
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-1, len(s.defaultData)), contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-1, len(s.data)),
expectErr: false, expectErr: false,
}, },
{ {
name: "Partial Content-Range, early end", name: "Partial Content-Range, early end",
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-2, len(s.defaultData)), contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-2, len(s.data)),
expectErr: true, expectErr: true,
}, },
{ {
name: "Partial Content-Range, late start", name: "Partial Content-Range, late start",
contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.defaultData)-1, len(s.defaultData)), contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.data)-1, len(s.data)),
expectErr: true, expectErr: true,
}, },
{ {
@@ -139,39 +117,41 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
}, },
{ {
name: "Unknown Content-Range range", name: "Unknown Content-Range range",
contentRange: fmt.Sprintf("bytes */%d", len(s.defaultData)), contentRange: fmt.Sprintf("bytes */%d", len(s.data)),
expectErr: true, expectErr: true,
}, },
{ {
name: "Unknown Content-Range size, full range", name: "Unknown Content-Range size, full range",
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-1), contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-1),
expectErr: false, expectErr: false,
}, },
{ {
name: "Unknown Content-Range size, early end", name: "Unknown Content-Range size, early end",
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-2), contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-2),
expectErr: true, expectErr: true,
}, },
{ {
name: "Unknown Content-Range size, late start", name: "Unknown Content-Range size, late start",
contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.defaultData)-1), contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.data)-1),
expectErr: true, expectErr: true,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
s.header.Set("Content-Range", tc.contentRange) s.testServer().
SetHeaders(httpheaders.ContentRange, tc.contentRange).
SetStatusCode(http.StatusPartialContent)
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
if tc.expectErr { if tc.expectErr {
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(http.StatusNotFound, ierrors.Wrap(err, 0).StatusCode())
} else { } else {
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(imgdata) s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format()) s.Require().Equal(imagetype.JPEG, imgdata.Format())
} }
}) })
@@ -179,11 +159,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
} }
func (s *ImageDataTestSuite) TestDownloadStatusNotFound() { func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
s.status = http.StatusNotFound s.testServer().
s.data = []byte("Not Found") SetStatusCode(http.StatusNotFound).
s.header.Set("Content-Type", "text/plain") SetBody([]byte("Not Found")).
SetHeaders(httpheaders.ContentType, "text/plain")
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -191,11 +172,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
} }
func (s *ImageDataTestSuite) TestDownloadStatusForbidden() { func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
s.status = http.StatusForbidden s.testServer().
s.data = []byte("Forbidden") SetStatusCode(http.StatusForbidden).
s.header.Set("Content-Type", "text/plain") SetBody([]byte("Forbidden")).
SetHeaders(httpheaders.ContentType, "text/plain")
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -203,11 +185,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
} }
func (s *ImageDataTestSuite) TestDownloadStatusInternalServerError() { func (s *ImageDataTestSuite) TestDownloadStatusInternalServerError() {
s.status = http.StatusInternalServerError s.testServer().
s.data = []byte("Internal Server Error") SetStatusCode(http.StatusInternalServerError).
s.header.Set("Content-Type", "text/plain") SetBody([]byte("Internal Server Error")).
SetHeaders(httpheaders.ContentType, "text/plain")
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
@@ -221,7 +204,7 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
serverURL := fmt.Sprintf("http://%s", l.Addr().String()) serverURL := fmt.Sprintf("http://%s", l.Addr().String())
imgdata, _, err := s.factory.DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
@@ -229,19 +212,19 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
} }
func (s *ImageDataTestSuite) TestDownloadInvalidImage() { func (s *ImageDataTestSuite) TestDownloadInvalidImage() {
s.data = []byte("invalid") s.testServer().SetBody([]byte("invalid"))
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(http.StatusUnprocessableEntity, ierrors.Wrap(err, 0).StatusCode())
s.Require().Nil(imgdata) s.Require().Nil(imgdata)
} }
func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() { func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
config.AllowLoopbackSourceAddresses = false s.fetcherCfg().Transport.HTTP.AllowLoopbackSourceAddresses = false
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
@@ -249,11 +232,10 @@ func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
} }
func (s *ImageDataTestSuite) TestDownloadImageFileTooLarge() { func (s *ImageDataTestSuite) TestDownloadImageFileTooLarge() {
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{ imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{
MaxSrcFileSize: 1, MaxSrcFileSize: 1,
}) })
fmt.Println(err)
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode()) s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
s.Require().Nil(imgdata) s.Require().Nil(imgdata)
@@ -263,39 +245,43 @@ func (s *ImageDataTestSuite) TestDownloadGzip() {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
enc := gzip.NewWriter(buf) enc := gzip.NewWriter(buf)
_, err := enc.Write(s.defaultData) _, err := enc.Write(s.data)
s.Require().NoError(err) s.Require().NoError(err)
err = enc.Close() err = enc.Close()
s.Require().NoError(err) s.Require().NoError(err)
s.data = buf.Bytes() s.testServer().
s.header.Set("Content-Encoding", "gzip") SetBody(buf.Bytes()).
SetHeaders(
httpheaders.ContentEncoding, "gzip",
httpheaders.ContentLength, strconv.Itoa(buf.Len()), // Update Content-Length
)
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{}) imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(imgdata) s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format()) s.Require().Equal(imagetype.JPEG, imgdata.Format())
} }
func (s *ImageDataTestSuite) TestFromFile() { func (s *ImageDataTestSuite) TestFromFile() {
imgdata, err := s.factory.NewFromPath("../testdata/test1.jpg") imgdata, err := s.factory().NewFromPath("../testdata/test1.jpg")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(imgdata) s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format()) s.Require().Equal(imagetype.JPEG, imgdata.Format())
} }
func (s *ImageDataTestSuite) TestFromBase64() { func (s *ImageDataTestSuite) TestFromBase64() {
b64 := base64.StdEncoding.EncodeToString(s.defaultData) b64 := base64.StdEncoding.EncodeToString(s.data)
imgdata, err := s.factory.NewFromBase64(b64) imgdata, err := s.factory().NewFromBase64(b64)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(imgdata) s.Require().NotNil(imgdata)
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader())) s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
s.Require().Equal(imagetype.JPEG, imgdata.Format()) s.Require().Equal(imagetype.JPEG, imgdata.Format())
} }

View File

@@ -27,10 +27,7 @@ type ProcessingHandlerTestSuite struct {
func (s *ProcessingHandlerTestSuite) SetupTest() { func (s *ProcessingHandlerTestSuite) SetupTest() {
config.Reset() // We reset config only at the start of each test config.Reset() // We reset config only at the start of each test
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
// NOTE: This must be moved to security config
config.AllowLoopbackSourceAddresses = true
// NOTE: end note
} }
func (s *ProcessingHandlerTestSuite) SetupSubTest() { func (s *ProcessingHandlerTestSuite) SetupSubTest() {
@@ -142,13 +139,13 @@ func (s *ProcessingHandlerTestSuite) TestSourceNetworkValidation() {
// We wrap this in a subtest to reset s.router() // We wrap this in a subtest to reset s.router()
s.Run("AllowLoopbackSourceAddressesTrue", func() { s.Run("AllowLoopbackSourceAddressesTrue", func() {
config.AllowLoopbackSourceAddresses = true s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
res := s.GET(url) res := s.GET(url)
s.Require().Equal(http.StatusOK, res.StatusCode) s.Require().Equal(http.StatusOK, res.StatusCode)
}) })
s.Run("AllowLoopbackSourceAddressesFalse", func() { s.Run("AllowLoopbackSourceAddressesFalse", func() {
config.AllowLoopbackSourceAddresses = false s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = false
res := s.GET(url) res := s.GET(url)
s.Require().Equal(http.StatusNotFound, res.StatusCode) s.Require().Equal(http.StatusNotFound, res.StatusCode)
}) })
@@ -256,7 +253,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
} }
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughExpires() { func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughExpires() {
config.CacheControlPassthrough = true s.Config().Server.ResponseWriter.CacheControlPassthrough = true
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.Expires, time.Now().Add(1239*time.Second).UTC().Format(http.TimeFormat)) rw.Header().Set(httpheaders.Expires, time.Now().Add(1239*time.Second).UTC().Format(http.TimeFormat))
@@ -290,7 +287,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughDisabled() {
} }
func (s *ProcessingHandlerTestSuite) TestETagDisabled() { func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
config.ETagEnabled = false s.Config().Handlers.Processing.ETagEnabled = false
res := s.GET("/unsafe/rs:fill:4:4/plain/local:///test1.png") res := s.GET("/unsafe/rs:fill:4:4/plain/local:///test1.png")
@@ -299,7 +296,7 @@ func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
} }
func (s *ProcessingHandlerTestSuite) TestETagDataMatch() { func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
config.ETagEnabled = true s.Config().Handlers.Processing.ETagEnabled = true
etag := `"loremipsumdolor"` etag := `"loremipsumdolor"`
@@ -321,7 +318,8 @@ func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
} }
func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() { func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
config.LastModifiedEnabled = true s.Config().Handlers.Processing.LastModifiedEnabled = true
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT") rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
rw.WriteHeader(200) rw.WriteHeader(200)
@@ -335,7 +333,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
} }
func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() { func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
config.LastModifiedEnabled = false s.Config().Handlers.Processing.LastModifiedEnabled = false
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT") rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
rw.WriteHeader(200) rw.WriteHeader(200)
@@ -349,7 +347,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
} }
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedDisabled() { func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedDisabled() {
config.LastModifiedEnabled = false s.Config().Handlers.Processing.LastModifiedEnabled = false
data := s.TestData.Read("test1.png") data := s.TestData.Read("test1.png")
lastModified := "Wed, 21 Oct 2015 07:28:00 GMT" lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -368,7 +366,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedD
} }
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedEnabled() { func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedEnabled() {
config.LastModifiedEnabled = true s.Config().Handlers.Processing.LastModifiedEnabled = true
lastModified := "Wed, 21 Oct 2015 07:28:00 GMT" lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
modifiedSince := r.Header.Get(httpheaders.IfModifiedSince) modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
@@ -386,7 +384,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedE
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqCompareMoreRecentLastModifiedDisabled() { func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqCompareMoreRecentLastModifiedDisabled() {
data := s.TestData.Read("test1.png") data := s.TestData.Read("test1.png")
config.LastModifiedEnabled = false s.Config().Handlers.Processing.LastModifiedEnabled = false
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
modifiedSince := r.Header.Get(httpheaders.IfModifiedSince) modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
s.Empty(modifiedSince) s.Empty(modifiedSince)

View File

@@ -1082,10 +1082,10 @@ func applyURLOption(po *ProcessingOptions, name string, args []string, usedPrese
} }
func applyURLOptions(po *ProcessingOptions, options urlOptions, allowAll bool, usedPresets ...string) error { func applyURLOptions(po *ProcessingOptions, options urlOptions, allowAll bool, usedPresets ...string) error {
allowAll = allowAll || len(config.AllowedProcessiongOptions) == 0 allowAll = allowAll || len(config.AllowedProcessingOptions) == 0
for _, opt := range options { for _, opt := range options {
if !allowAll && !slices.Contains(config.AllowedProcessiongOptions, opt.Name) { if !allowAll && !slices.Contains(config.AllowedProcessingOptions, opt.Name) {
return newForbiddenOptionError("processing", opt.Name) return newForbiddenOptionError("processing", opt.Name)
} }

View File

@@ -646,7 +646,7 @@ func (s *ProcessingOptionsTestSuite) TestParseBase64URLOnlyPresets() {
} }
func (s *ProcessingOptionsTestSuite) TestParseAllowedOptions() { func (s *ProcessingOptionsTestSuite) TestParseAllowedOptions() {
config.AllowedProcessiongOptions = []string{"w", "h", "pr"} config.AllowedProcessingOptions = []string{"w", "h", "pr"}
presets["test1"] = urlOptions{ presets["test1"] = urlOptions{
urlOption{Name: "blur", Args: []string{"0.2"}}, urlOption{Name: "blur", Args: []string{"0.2"}},

View File

@@ -13,7 +13,6 @@ type (
ImageResolutionError string ImageResolutionError string
SecurityOptionsError struct{} SecurityOptionsError struct{}
SourceURLError string SourceURLError string
SourceAddressError string
) )
func newSignatureError(msg string) error { func newSignatureError(msg string) error {
@@ -75,15 +74,3 @@ func newSourceURLError(imageURL string) error {
} }
func (e SourceURLError) Error() string { return string(e) } func (e SourceURLError) Error() string { return string(e) }
func newSourceAddressError(msg string) error {
return ierrors.Wrap(
SourceAddressError(msg),
1,
ierrors.WithStatusCode(http.StatusNotFound),
ierrors.WithPublicMessage("Invalid source URL"),
ierrors.WithShouldReport(false),
)
}
func (e SourceAddressError) Error() string { return string(e) }

View File

@@ -1,9 +1,6 @@
package security package security
import ( import (
"fmt"
"net"
"github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/config"
) )
@@ -20,29 +17,3 @@ func VerifySourceURL(imageURL string) error {
return newSourceURLError(imageURL) return newSourceURLError(imageURL)
} }
func VerifySourceNetwork(addr string) error {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ip := net.ParseIP(host)
if ip == nil {
return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
}
if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
}
if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
}
if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
}
return nil
}

View File

@@ -51,7 +51,7 @@ func NewLazySuiteObj[T any](
// Get the [LazySuite] instance // Get the [LazySuite] instance
lazy := s.Lazy() lazy := s.Lazy()
// Create the [LazyObj] instance // Create the [LazyObj] instance
obj, cancel := NewLazyObj(lazy, newFn, dropFn...) obj, cancel := newLazyObj(lazy, newFn, dropFn...)
// Add cleanup function to the resets list // Add cleanup function to the resets list
lazy.resets = append(lazy.resets, cancel) lazy.resets = append(lazy.resets, cancel)

View File

@@ -23,10 +23,10 @@ type LazyObjNew[T any] func() (T, error)
// If the object was not yet initialized, the callback is not called. // If the object was not yet initialized, the callback is not called.
type LazyObjDrop[T any] func(T) error type LazyObjDrop[T any] func(T) error
// NewLazyObj creates a new [LazyObj] that initializes the object on the first call. // newLazyObj creates a new [LazyObj] that initializes the object on the first call.
// It returns a function that can be called to get the object and a cancel function // It returns a function that can be called to get the object and a cancel function
// that can be called to reset the object. // that can be called to reset the object.
func NewLazyObj[T any]( func newLazyObj[T any](
s LazyObjT, s LazyObjT,
newFn LazyObjNew[T], newFn LazyObjNew[T],
dropFn ...LazyObjDrop[T], dropFn ...LazyObjDrop[T],

View File

@@ -26,9 +26,6 @@ type TestDataProvider struct {
// New creates a new TestDataProvider // New creates a new TestDataProvider
func NewTestDataProvider(t TestDataProviderT) *TestDataProvider { func NewTestDataProvider(t TestDataProviderT) *TestDataProvider {
// if h, ok := t.(interface{ Helper() }); ok {
// h.Helper()
// }
t().Helper() t().Helper()
path, err := findProjectRoot() path, err := findProjectRoot()

122
testutil/test_server.go Normal file
View File

@@ -0,0 +1,122 @@
package testutil
import (
"context"
"net/http"
"net/http/httptest"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/stretchr/testify/require"
)
// TestServerHookFunc is a function type for in-request hooks
type TestServerHookFunc func(r *http.Request, rw http.ResponseWriter)
// Sugar alias
type LazyTestServer = LazyObj[*TestServer]
// TestServer is a syntax sugar wrapper over httptest.Server
type TestServer struct {
testServer *httptest.Server
status int
data []byte
header http.Header
hook TestServerHookFunc
}
// NewLazySuiteTestServer creates a lazy TestServer object for use in test suites
func NewLazySuiteTestServer(
l LazySuiteFrom,
init ...func(*TestServer) error,
) (LazyObj[*TestServer], context.CancelFunc) {
return NewLazySuiteObj(
l,
func() (*TestServer, error) {
s := NewTestServer()
if len(init) > 0 {
for _, fn := range init {
if fn == nil {
continue
}
err := fn(s)
require.NoError(l.Lazy().T(), err, "Failed to reset test server")
}
}
return s, nil
},
func(s *TestServer) error {
s.Close()
return nil
},
)
}
// New creates and starts new http.TestServer
func NewTestServer() *TestServer {
ts := &TestServer{
status: http.StatusOK,
header: make(http.Header),
data: nil,
hook: nil,
}
return ts.start()
}
// SetStatusCode sets the status code that will be returned by the server
func (s *TestServer) SetStatusCode(status int) *TestServer {
s.status = status
return s
}
// SetBody sets the body that will be returned by the server
func (s *TestServer) SetBody(data []byte) *TestServer {
s.data = data
return s
}
// WithHeader adds headers that will be returned by the server.
// Odd arguments are treated as keys, even arguments as values.
func (s *TestServer) SetHeaders(kv ...string) *TestServer {
for i := 0; i+1 < len(kv); i += 2 {
key := kv[i]
value := kv[i+1]
s.header.Set(key, value)
}
return s
}
// SetHook sets a function that will be called on each request. It is called
// after headsers are set, but before status and body are written.
func (s *TestServer) SetHook(f TestServerHookFunc) *TestServer {
s.hook = f
return s
}
// Start starts the server
func (s *TestServer) start() *TestServer {
s.testServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpheaders.CopyAll(s.header, w.Header(), true)
if s.hook != nil {
s.hook(r, w)
}
w.WriteHeader(s.status)
w.Write(s.data)
}))
return s
}
// Close stops the server
func (s *TestServer) Close() {
s.testServer.Close()
}
// URL returns the server URL
func (s *TestServer) URL() string {
return s.testServer.URL
}