mirror of
https://github.com/imgproxy/imgproxy.git
synced 2025-09-27 12:07:59 +02:00
TestServer, AllowNetworks -> http.Transport
This commit is contained in:
@@ -3,77 +3,56 @@ package auximageprovider
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/testutil"
|
||||
)
|
||||
|
||||
type ImageProviderTestSuite struct {
|
||||
suite.Suite
|
||||
testutil.LazySuite
|
||||
|
||||
server *httptest.Server
|
||||
testData []byte
|
||||
testDataB64 string
|
||||
|
||||
// Server state
|
||||
status int
|
||||
data []byte
|
||||
header http.Header
|
||||
testServer testutil.LazyTestServer
|
||||
idf *imagedata.Factory
|
||||
}
|
||||
|
||||
func (s *ImageProviderTestSuite) SetupSuite() {
|
||||
config.Reset()
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
s.testData = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
|
||||
s.testDataB64 = base64.StdEncoding.EncodeToString(s.testData)
|
||||
|
||||
// Load test image data
|
||||
f, err := os.Open("../testdata/test1.jpg")
|
||||
s.Require().NoError(err)
|
||||
defer f.Close()
|
||||
fc := fetcher.NewDefaultConfig()
|
||||
fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
f, err := fetcher.New(&fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.testData = data
|
||||
s.testDataB64 = base64.StdEncoding.EncodeToString(data)
|
||||
s.idf = imagedata.NewFactory(f)
|
||||
|
||||
// Create test server
|
||||
s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
for k, vv := range s.header {
|
||||
for _, v := range vv {
|
||||
rw.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
s.testServer, _ = testutil.NewLazySuiteTestServer(
|
||||
s,
|
||||
func(srv *testutil.TestServer) error {
|
||||
srv.SetHeaders(
|
||||
httpheaders.ContentType, "image/jpeg",
|
||||
httpheaders.ContentLength, strconv.Itoa(len(s.testData)),
|
||||
).SetBody(s.testData)
|
||||
|
||||
data := s.data
|
||||
if data == nil {
|
||||
data = s.testData
|
||||
}
|
||||
|
||||
rw.Header().Set(httpheaders.ContentLength, strconv.Itoa(len(data)))
|
||||
rw.WriteHeader(s.status)
|
||||
rw.Write(data)
|
||||
}))
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *ImageProviderTestSuite) TearDownSuite() {
|
||||
s.server.Close()
|
||||
}
|
||||
|
||||
func (s *ImageProviderTestSuite) SetupTest() {
|
||||
s.status = http.StatusOK
|
||||
s.data = nil
|
||||
s.header = http.Header{}
|
||||
s.header.Set(httpheaders.ContentType, "image/jpeg")
|
||||
func (s *ImageProviderTestSuite) SetupSubTest() {
|
||||
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
|
||||
s.ResetLazyObjects()
|
||||
}
|
||||
|
||||
// Helper function to read data from ImageData
|
||||
@@ -114,7 +93,7 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
|
||||
},
|
||||
{
|
||||
name: "URL",
|
||||
config: &StaticConfig{URL: s.server.URL},
|
||||
config: &StaticConfig{URL: s.testServer().URL()},
|
||||
validateFunc: func(provider Provider) {
|
||||
s.Equal(s.testData, s.readImageData(provider))
|
||||
},
|
||||
@@ -149,10 +128,12 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
|
||||
},
|
||||
{
|
||||
name: "HeadersPassedThrough",
|
||||
config: &StaticConfig{URL: s.server.URL},
|
||||
config: &StaticConfig{URL: s.testServer().URL()},
|
||||
setupFunc: func() {
|
||||
s.header.Set("X-Custom-Header", "test-value")
|
||||
s.header.Set(httpheaders.CacheControl, "max-age=3600")
|
||||
s.testServer().SetHeaders(
|
||||
"X-Custom-Header", "test-value",
|
||||
httpheaders.CacheControl, "max-age=3600",
|
||||
)
|
||||
},
|
||||
validateFunc: func(provider Provider) {
|
||||
imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{})
|
||||
@@ -167,19 +148,13 @@ func (s *ImageProviderTestSuite) TestNewProvider() {
|
||||
},
|
||||
}
|
||||
|
||||
fc := fetcher.NewDefaultConfig()
|
||||
f, err := fetcher.New(&fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
idf := imagedata.NewFactory(f)
|
||||
|
||||
for _, tt := range tests {
|
||||
s.T().Run(tt.name, func(t *testing.T) {
|
||||
if tt.setupFunc != nil {
|
||||
tt.setupFunc()
|
||||
}
|
||||
|
||||
provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", idf)
|
||||
provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image", s.idf)
|
||||
|
||||
if tt.expectError {
|
||||
s.Require().Error(err)
|
||||
|
@@ -58,7 +58,7 @@ var (
|
||||
PngUnlimited bool
|
||||
SvgUnlimited bool
|
||||
MaxResultDimension int
|
||||
AllowedProcessiongOptions []string
|
||||
AllowedProcessingOptions []string
|
||||
AllowSecurityOptions bool
|
||||
|
||||
JpegProgressive bool
|
||||
@@ -267,7 +267,7 @@ func Reset() {
|
||||
PngUnlimited = false
|
||||
SvgUnlimited = false
|
||||
MaxResultDimension = 0
|
||||
AllowedProcessiongOptions = make([]string, 0)
|
||||
AllowedProcessingOptions = make([]string, 0)
|
||||
AllowSecurityOptions = false
|
||||
|
||||
JpegProgressive = false
|
||||
@@ -502,7 +502,7 @@ func Configure() error {
|
||||
configurators.Bool(&SvgUnlimited, "IMGPROXY_SVG_UNLIMITED")
|
||||
|
||||
configurators.Int(&MaxResultDimension, "IMGPROXY_MAX_RESULT_DIMENSION")
|
||||
configurators.StringSlice(&AllowedProcessiongOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
|
||||
configurators.StringSlice(&AllowedProcessingOptions, "IMGPROXY_ALLOWED_PROCESSING_OPTIONS")
|
||||
|
||||
configurators.Bool(&AllowSecurityOptions, "IMGPROXY_ALLOW_SECURITY_OPTIONS")
|
||||
|
||||
|
@@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher/transport/generichttp"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/security"
|
||||
)
|
||||
|
||||
const msgSourceImageIsUnreachable = "Source image is unreachable"
|
||||
@@ -157,13 +158,21 @@ func (e NotModifiedError) Headers() http.Header {
|
||||
func WrapError(err error) error {
|
||||
isTimeout := false
|
||||
|
||||
var secArrdErr security.SourceAddressError
|
||||
var secArrdErr generichttp.SourceAddressError
|
||||
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
isTimeout = true
|
||||
case errors.Is(err, context.Canceled):
|
||||
return newImageRequestCanceledError(err)
|
||||
case err == io.ErrUnexpectedEOF:
|
||||
return ierrors.Wrap(
|
||||
newImageRequestError(err),
|
||||
1,
|
||||
ierrors.WithPublicMessage("source image is corrupted"),
|
||||
ierrors.WithShouldReport(false),
|
||||
ierrors.WithStatusCode(http.StatusUnprocessableEntity),
|
||||
)
|
||||
case errors.As(err, &secArrdErr):
|
||||
return ierrors.Wrap(
|
||||
err,
|
||||
|
@@ -10,15 +10,21 @@ import (
|
||||
|
||||
// Config holds the configuration for the generic HTTP transport
|
||||
type Config struct {
|
||||
ClientKeepAliveTimeout time.Duration
|
||||
IgnoreSslVerification bool
|
||||
ClientKeepAliveTimeout time.Duration
|
||||
IgnoreSslVerification bool
|
||||
AllowLoopbackSourceAddresses bool
|
||||
AllowLinkLocalSourceAddresses bool
|
||||
AllowPrivateSourceAddresses bool
|
||||
}
|
||||
|
||||
// NewDefaultConfig returns a new default configuration for the generic HTTP transport
|
||||
func NewDefaultConfig() Config {
|
||||
return Config{
|
||||
ClientKeepAliveTimeout: 90 * time.Second,
|
||||
IgnoreSslVerification: false,
|
||||
ClientKeepAliveTimeout: 90 * time.Second,
|
||||
IgnoreSslVerification: false,
|
||||
AllowLoopbackSourceAddresses: false,
|
||||
AllowLinkLocalSourceAddresses: false,
|
||||
AllowPrivateSourceAddresses: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +34,9 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
|
||||
|
||||
c.ClientKeepAliveTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second
|
||||
c.IgnoreSslVerification = config.IgnoreSslVerification
|
||||
c.AllowLinkLocalSourceAddresses = config.AllowLinkLocalSourceAddresses
|
||||
c.AllowLoopbackSourceAddresses = config.AllowLoopbackSourceAddresses
|
||||
c.AllowPrivateSourceAddresses = config.AllowPrivateSourceAddresses
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
23
fetcher/transport/generichttp/errors.go
Normal file
23
fetcher/transport/generichttp/errors.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package generichttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
)
|
||||
|
||||
type (
|
||||
SourceAddressError string
|
||||
)
|
||||
|
||||
func newSourceAddressError(msg string) error {
|
||||
return ierrors.Wrap(
|
||||
SourceAddressError(msg),
|
||||
1,
|
||||
ierrors.WithStatusCode(http.StatusNotFound),
|
||||
ierrors.WithPublicMessage("Invalid source URL"),
|
||||
ierrors.WithShouldReport(false),
|
||||
)
|
||||
}
|
||||
|
||||
func (e SourceAddressError) Error() string { return string(e) }
|
@@ -3,12 +3,12 @@ package generichttp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/security"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
|
||||
|
||||
if verifyNetworks {
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return security.VerifySourceNetwork(address)
|
||||
return verifySourceNetwork(address, config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,3 +66,29 @@ func New(verifyNetworks bool, config *Config) (*http.Transport, error) {
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func verifySourceNetwork(addr string, config *Config) error {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
|
||||
}
|
||||
|
||||
if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
|
||||
return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
|
||||
}
|
||||
|
||||
if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
|
||||
return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
|
||||
}
|
||||
|
||||
if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
|
||||
return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -1,9 +1,8 @@
|
||||
package security
|
||||
package generichttp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -100,24 +99,14 @@ func TestVerifySourceNetwork(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Backup original config
|
||||
originalLoopback := config.AllowLoopbackSourceAddresses
|
||||
originalLinkLocal := config.AllowLinkLocalSourceAddresses
|
||||
originalPrivate := config.AllowPrivateSourceAddresses
|
||||
|
||||
// Restore original config after test
|
||||
defer func() {
|
||||
config.AllowLoopbackSourceAddresses = originalLoopback
|
||||
config.AllowLinkLocalSourceAddresses = originalLinkLocal
|
||||
config.AllowPrivateSourceAddresses = originalPrivate
|
||||
}()
|
||||
config := NewDefaultConfig()
|
||||
|
||||
// Override config for the test
|
||||
config.AllowLoopbackSourceAddresses = tc.allowLoopback
|
||||
config.AllowLinkLocalSourceAddresses = tc.allowLinkLocal
|
||||
config.AllowPrivateSourceAddresses = tc.allowPrivate
|
||||
|
||||
err := VerifySourceNetwork(tc.addr)
|
||||
err := verifySourceNetwork(tc.addr, &config)
|
||||
|
||||
if tc.expectErr {
|
||||
require.Error(t, err)
|
@@ -1 +0,0 @@
|
||||
package processing
|
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -22,23 +21,24 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
testDataPath = "../../testdata"
|
||||
)
|
||||
|
||||
type HandlerTestSuite struct {
|
||||
testutil.LazySuite
|
||||
|
||||
testData *testutil.TestDataProvider
|
||||
|
||||
rwConf testutil.LazyObj[*responsewriter.Config]
|
||||
rwFactory testutil.LazyObj[*responsewriter.Factory]
|
||||
|
||||
config testutil.LazyObj[*Config]
|
||||
handler testutil.LazyObj[*Handler]
|
||||
|
||||
testServer testutil.LazyTestServer
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) SetupSuite() {
|
||||
config.Reset()
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
|
||||
s.testData = testutil.NewTestDataProvider(s.T)
|
||||
|
||||
s.rwConf, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
@@ -67,6 +67,7 @@ func (s *HandlerTestSuite) SetupSuite() {
|
||||
s,
|
||||
func() (*Handler, error) {
|
||||
fc := fetcher.NewDefaultConfig()
|
||||
fc.Transport.HTTP.AllowLoopbackSourceAddresses = true
|
||||
|
||||
fetcher, err := fetcher.New(&fc)
|
||||
s.Require().NoError(err)
|
||||
@@ -75,36 +76,27 @@ func (s *HandlerTestSuite) SetupSuite() {
|
||||
},
|
||||
)
|
||||
|
||||
s.testServer, _ = testutil.NewLazySuiteTestServer(s)
|
||||
|
||||
// Silence logs during tests
|
||||
logrus.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) TearDownSuite() {
|
||||
config.Reset()
|
||||
logrus.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) SetupTest() {
|
||||
config.Reset()
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) SetupSubTest() {
|
||||
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
|
||||
s.ResetLazyObjects()
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) readTestFile(name string) []byte {
|
||||
data, err := os.ReadFile(filepath.Join(testDataPath, name))
|
||||
s.Require().NoError(err)
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) execute(
|
||||
imageURL string,
|
||||
header http.Header,
|
||||
po *options.ProcessingOptions,
|
||||
) *httptest.ResponseRecorder {
|
||||
) *http.Response {
|
||||
imageURL = s.testServer().URL() + imageURL
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
httpheaders.CopyAll(header, req.Header, true)
|
||||
|
||||
@@ -115,51 +107,42 @@ func (s *HandlerTestSuite) execute(
|
||||
err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
|
||||
s.Require().NoError(err)
|
||||
|
||||
return rw
|
||||
return rw.Result()
|
||||
}
|
||||
|
||||
// TestHandlerBasicRequest checks basic streaming request
|
||||
func (s *HandlerTestSuite) TestHandlerBasicRequest() {
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
|
||||
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
res := s.execute("", nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
|
||||
|
||||
// Verify we get the original image data
|
||||
actual := rw.Body.Bytes()
|
||||
actual, err := io.ReadAll(res.Body)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(data, actual)
|
||||
}
|
||||
|
||||
// TestHandlerResponseHeadersPassthrough checks that original response headers are
|
||||
// passed through to the client
|
||||
func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
contentLength := len(data)
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.Header().Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
|
||||
w.Header().Set(httpheaders.AcceptRanges, "bytes")
|
||||
w.Header().Set(httpheaders.Etag, "etag")
|
||||
w.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().SetHeaders(
|
||||
httpheaders.ContentType, "image/png",
|
||||
httpheaders.ContentLength, strconv.Itoa(contentLength),
|
||||
httpheaders.AcceptRanges, "bytes",
|
||||
httpheaders.Etag, "etag",
|
||||
httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT",
|
||||
).SetBody(data)
|
||||
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
res := s.execute("", nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
s.Require().Equal("image/png", res.Header.Get(httpheaders.ContentType))
|
||||
s.Require().Equal(strconv.Itoa(contentLength), res.Header.Get(httpheaders.ContentLength))
|
||||
@@ -172,42 +155,34 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
|
||||
// to the server
|
||||
func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
|
||||
etag := `"test-etag-123"`
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify that If-None-Match header is passed through
|
||||
s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
|
||||
s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
|
||||
s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
|
||||
|
||||
w.Header().Set(httpheaders.Etag, etag)
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().
|
||||
SetBody(data).
|
||||
SetHeaders(httpheaders.Etag, etag).
|
||||
SetHook(func(r *http.Request, rw http.ResponseWriter) {
|
||||
// Verify that If-None-Match header is passed through
|
||||
s.Equal(etag, r.Header.Get(httpheaders.IfNoneMatch))
|
||||
s.Equal("gzip", r.Header.Get(httpheaders.AcceptEncoding))
|
||||
s.Equal("bytes=*", r.Header.Get(httpheaders.Range))
|
||||
})
|
||||
|
||||
h := make(http.Header)
|
||||
h.Set(httpheaders.IfNoneMatch, etag)
|
||||
h.Set(httpheaders.AcceptEncoding, "gzip")
|
||||
h.Set(httpheaders.Range, "bytes=*")
|
||||
|
||||
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
|
||||
res := s.execute("", h, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
s.Require().Equal(etag, res.Header.Get(httpheaders.Etag))
|
||||
}
|
||||
|
||||
// TestHandlerContentDisposition checks that Content-Disposition header is set correctly
|
||||
func (s *HandlerTestSuite) TestHandlerContentDisposition() {
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
|
||||
|
||||
po := &options.ProcessingOptions{
|
||||
Filename: "custom_name",
|
||||
@@ -215,10 +190,8 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
|
||||
}
|
||||
|
||||
// Use a URL with a .png extension to help content disposition logic
|
||||
imageURL := ts.URL + "/test.png"
|
||||
rw := s.execute(imageURL, nil, po)
|
||||
res := s.execute("/test.png", nil, po)
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "custom_name.png")
|
||||
s.Require().Contains(res.Header.Get(httpheaders.ContentDisposition), "attachment")
|
||||
@@ -229,7 +202,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
type testCase struct {
|
||||
name string
|
||||
cacheControlPassthrough bool
|
||||
setupOriginHeaders func(http.ResponseWriter)
|
||||
setupOriginHeaders func()
|
||||
timestampOffset *time.Duration // nil for no timestamp, otherwise the offset from now
|
||||
expectedStatusCode int
|
||||
validate func(*testing.T, *http.Response)
|
||||
@@ -250,8 +223,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "Passthrough",
|
||||
cacheControlPassthrough: true,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {
|
||||
w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
|
||||
setupOriginHeaders: func() {
|
||||
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
|
||||
},
|
||||
timestampOffset: nil,
|
||||
expectedStatusCode: 200,
|
||||
@@ -263,8 +236,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "ExpiresPassthrough",
|
||||
cacheControlPassthrough: true,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {
|
||||
w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
|
||||
setupOriginHeaders: func() {
|
||||
s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
|
||||
},
|
||||
timestampOffset: nil,
|
||||
expectedStatusCode: 200,
|
||||
@@ -278,8 +251,8 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "PassthroughDisabled",
|
||||
cacheControlPassthrough: false,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {
|
||||
w.Header().Set(httpheaders.CacheControl, "max-age=3600, public")
|
||||
setupOriginHeaders: func() {
|
||||
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=3600, public")
|
||||
},
|
||||
timestampOffset: nil,
|
||||
expectedStatusCode: 200,
|
||||
@@ -291,7 +264,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "WithProcessingOptionsExpires",
|
||||
cacheControlPassthrough: false,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {}, // No origin headers
|
||||
timestampOffset: &oneHour,
|
||||
expectedStatusCode: 200,
|
||||
validate: func(t *testing.T, res *http.Response) {
|
||||
@@ -303,9 +275,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "ProcessingOptionsOverridesOrigin",
|
||||
cacheControlPassthrough: true,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {
|
||||
setupOriginHeaders: func() {
|
||||
// Origin has a longer cache time
|
||||
w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
|
||||
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
|
||||
},
|
||||
timestampOffset: &thirtyMinutes,
|
||||
expectedStatusCode: 200,
|
||||
@@ -318,10 +290,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "BothHeadersPassthroughEnabled",
|
||||
cacheControlPassthrough: true,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {
|
||||
setupOriginHeaders: func() {
|
||||
// Origin has both Cache-Control and Expires headers
|
||||
w.Header().Set(httpheaders.CacheControl, "max-age=1800, public")
|
||||
w.Header().Set(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
|
||||
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=1800, public")
|
||||
s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(oneHour).UTC().Format(http.TimeFormat))
|
||||
},
|
||||
timestampOffset: nil,
|
||||
expectedStatusCode: 200,
|
||||
@@ -336,10 +308,10 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "ProcessingOptionsOverridesBothOriginHeaders",
|
||||
cacheControlPassthrough: true,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {
|
||||
setupOriginHeaders: func() {
|
||||
// Origin has both Cache-Control and Expires headers with longer cache times
|
||||
w.Header().Set(httpheaders.CacheControl, "max-age=7200, public")
|
||||
w.Header().Set(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
|
||||
s.testServer().SetHeaders(httpheaders.CacheControl, "max-age=7200, public")
|
||||
s.testServer().SetHeaders(httpheaders.Expires, time.Now().Add(twoHours).UTC().Format(http.TimeFormat))
|
||||
},
|
||||
timestampOffset: &fortyFiveMinutes, // Shorter than origin headers
|
||||
expectedStatusCode: 200,
|
||||
@@ -352,7 +324,6 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
{
|
||||
name: "NoOriginHeaders",
|
||||
cacheControlPassthrough: false,
|
||||
setupOriginHeaders: func(w http.ResponseWriter) {}, // Origin has no cache headers
|
||||
timestampOffset: nil,
|
||||
expectedStatusCode: 200,
|
||||
validate: func(t *testing.T, res *http.Response) {
|
||||
@@ -363,15 +334,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tc.setupOriginHeaders(w)
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
if tc.setupOriginHeaders != nil {
|
||||
tc.setupOriginHeaders()
|
||||
}
|
||||
|
||||
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
|
||||
|
||||
s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
|
||||
s.rwConf().DefaultTTL = 4242
|
||||
@@ -383,9 +352,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
po.Expires = &expires
|
||||
}
|
||||
|
||||
rw := s.execute(ts.URL, nil, po)
|
||||
|
||||
res := rw.Result()
|
||||
res := s.execute("", nil, po)
|
||||
s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
|
||||
tc.validate(s.T(), res)
|
||||
})
|
||||
@@ -405,85 +372,64 @@ func (s *HandlerTestSuite) maxAgeValue(res *http.Response) time.Duration {
|
||||
|
||||
// TestHandlerSecurityHeaders tests the security headers set by the streaming service.
|
||||
func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
|
||||
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
res := s.execute("", nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
s.Require().Equal(http.StatusOK, res.StatusCode)
|
||||
s.Require().Equal("script-src 'none'", res.Header.Get(httpheaders.ContentSecurityPolicy))
|
||||
}
|
||||
|
||||
// TestHandlerErrorResponse tests the error responses from the streaming service.
|
||||
func (s *HandlerTestSuite) TestHandlerErrorResponse() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("Not Found"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().SetStatusCode(http.StatusNotFound).SetBody([]byte("Not Found"))
|
||||
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
res := s.execute("", nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(404, res.StatusCode)
|
||||
s.Require().Equal(http.StatusNotFound, res.StatusCode)
|
||||
}
|
||||
|
||||
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
|
||||
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
|
||||
s.config().CookiePassthrough = true
|
||||
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify cookies are passed through
|
||||
cookie, cerr := r.Cookie("test_cookie")
|
||||
if cerr == nil {
|
||||
s.Equal("test_value", cookie.Value)
|
||||
}
|
||||
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().
|
||||
SetHeaders(httpheaders.Cookie, "test_cookie=test_value").
|
||||
SetHook(func(r *http.Request, rw http.ResponseWriter) {
|
||||
// Verify cookies are passed through
|
||||
cookie, cerr := r.Cookie("test_cookie")
|
||||
if cerr == nil {
|
||||
s.Equal("test_value", cookie.Value)
|
||||
}
|
||||
}).SetBody(data)
|
||||
|
||||
h := make(http.Header)
|
||||
h.Set(httpheaders.Cookie, "test_cookie=test_value")
|
||||
|
||||
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
|
||||
res := s.execute("", h, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
}
|
||||
|
||||
// TestHandlerCanonicalHeader tests that the canonical header is set correctly
|
||||
func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
|
||||
data := s.readTestFile("test1.png")
|
||||
data := s.testData.Read("test1.png")
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httpheaders.ContentType, "image/png")
|
||||
w.WriteHeader(200)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
s.testServer().SetHeaders(httpheaders.ContentType, "image/png").SetBody(data)
|
||||
|
||||
for _, sc := range []bool{true, false} {
|
||||
s.rwConf().SetCanonicalHeader = sc
|
||||
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
res := s.execute("", nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
|
||||
if sc {
|
||||
s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, ts.URL))
|
||||
s.Require().Contains(res.Header.Get(httpheaders.Link), fmt.Sprintf(`<%s>; rel="canonical"`, s.testServer().URL()))
|
||||
} else {
|
||||
s.Require().Empty(res.Header.Get(httpheaders.Link))
|
||||
}
|
||||
|
@@ -6,17 +6,13 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
@@ -25,88 +21,70 @@ import (
|
||||
)
|
||||
|
||||
type ImageDataTestSuite struct {
|
||||
suite.Suite
|
||||
testutil.LazySuite
|
||||
|
||||
server *httptest.Server
|
||||
fetcherCfg testutil.LazyObj[*fetcher.Config]
|
||||
factory testutil.LazyObj[*Factory]
|
||||
testServer testutil.LazyTestServer
|
||||
|
||||
status int
|
||||
data []byte
|
||||
header http.Header
|
||||
check func(*http.Request)
|
||||
factory *Factory
|
||||
|
||||
defaultData []byte
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) SetupSuite() {
|
||||
config.Reset()
|
||||
config.ClientKeepAliveTimeout = 0
|
||||
s.data = testutil.NewTestDataProvider(s.T).Read("test1.jpg")
|
||||
|
||||
f, err := os.Open("../testdata/test1.jpg")
|
||||
s.Require().NoError(err)
|
||||
defer f.Close()
|
||||
s.fetcherCfg, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
func() (*fetcher.Config, error) {
|
||||
c := fetcher.NewDefaultConfig()
|
||||
c.Transport.HTTP.AllowLoopbackSourceAddresses = true
|
||||
c.Transport.HTTP.ClientKeepAliveTimeout = 0
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
s.Require().NoError(err)
|
||||
return &c, nil
|
||||
},
|
||||
)
|
||||
|
||||
s.defaultData = data
|
||||
s.factory, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
func() (*Factory, error) {
|
||||
fetcher, err := fetcher.New(s.fetcherCfg())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if s.check != nil {
|
||||
s.check(r)
|
||||
}
|
||||
return NewFactory(fetcher), nil
|
||||
},
|
||||
)
|
||||
|
||||
httpheaders.CopyAll(s.header, rw.Header(), true)
|
||||
s.testServer, _ = testutil.NewLazySuiteTestServer(
|
||||
s,
|
||||
func(srv *testutil.TestServer) error {
|
||||
// Default headers and body for 200 OK response
|
||||
srv.SetHeaders(
|
||||
httpheaders.ContentType, "image/jpeg",
|
||||
httpheaders.ContentLength, strconv.Itoa(len(s.data)),
|
||||
).SetBody(s.data)
|
||||
|
||||
data := s.data
|
||||
if data == nil {
|
||||
data = s.defaultData
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
|
||||
rw.WriteHeader(s.status)
|
||||
rw.Write(data)
|
||||
}))
|
||||
|
||||
c, err := fetcher.LoadConfigFromEnv(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
fetcher, err := fetcher.New(c)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.factory = NewFactory(fetcher)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TearDownSuite() {
|
||||
s.server.Close()
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) SetupTest() {
|
||||
config.Reset()
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
|
||||
s.status = http.StatusOK
|
||||
s.data = nil
|
||||
s.check = nil
|
||||
|
||||
s.header = http.Header{}
|
||||
s.header.Set("Content-Type", "image/jpeg")
|
||||
|
||||
func (s *ImageDataTestSuite) SetupSubTest() {
|
||||
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
|
||||
s.ResetLazyObjects()
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadStatusOK() {
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(imgdata)
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
|
||||
s.Require().Equal(imagetype.JPEG, imgdata.Format())
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
|
||||
s.status = http.StatusPartialContent
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
contentRange string
|
||||
@@ -114,17 +92,17 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
|
||||
}{
|
||||
{
|
||||
name: "Full Content-Range",
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-1, len(s.defaultData)),
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-1, len(s.data)),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Partial Content-Range, early end",
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.defaultData)-2, len(s.defaultData)),
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/%d", len(s.data)-2, len(s.data)),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Partial Content-Range, late start",
|
||||
contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.defaultData)-1, len(s.defaultData)),
|
||||
contentRange: fmt.Sprintf("bytes 1-%d/%d", len(s.data)-1, len(s.data)),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
@@ -139,39 +117,41 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
|
||||
},
|
||||
{
|
||||
name: "Unknown Content-Range range",
|
||||
contentRange: fmt.Sprintf("bytes */%d", len(s.defaultData)),
|
||||
contentRange: fmt.Sprintf("bytes */%d", len(s.data)),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Unknown Content-Range size, full range",
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-1),
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-1),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Unknown Content-Range size, early end",
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.defaultData)-2),
|
||||
contentRange: fmt.Sprintf("bytes 0-%d/*", len(s.data)-2),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Unknown Content-Range size, late start",
|
||||
contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.defaultData)-1),
|
||||
contentRange: fmt.Sprintf("bytes 1-%d/*", len(s.data)-1),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
s.header.Set("Content-Range", tc.contentRange)
|
||||
s.testServer().
|
||||
SetHeaders(httpheaders.ContentRange, tc.contentRange).
|
||||
SetStatusCode(http.StatusPartialContent)
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
if tc.expectErr {
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
|
||||
s.Require().Equal(http.StatusNotFound, ierrors.Wrap(err, 0).StatusCode())
|
||||
} else {
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(imgdata)
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
|
||||
s.Require().Equal(imagetype.JPEG, imgdata.Format())
|
||||
}
|
||||
})
|
||||
@@ -179,11 +159,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusPartialContent() {
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
|
||||
s.status = http.StatusNotFound
|
||||
s.data = []byte("Not Found")
|
||||
s.header.Set("Content-Type", "text/plain")
|
||||
s.testServer().
|
||||
SetStatusCode(http.StatusNotFound).
|
||||
SetBody([]byte("Not Found")).
|
||||
SetHeaders(httpheaders.ContentType, "text/plain")
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
|
||||
@@ -191,11 +172,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusNotFound() {
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
|
||||
s.status = http.StatusForbidden
|
||||
s.data = []byte("Forbidden")
|
||||
s.header.Set("Content-Type", "text/plain")
|
||||
s.testServer().
|
||||
SetStatusCode(http.StatusForbidden).
|
||||
SetBody([]byte("Forbidden")).
|
||||
SetHeaders(httpheaders.ContentType, "text/plain")
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
|
||||
@@ -203,11 +185,12 @@ func (s *ImageDataTestSuite) TestDownloadStatusForbidden() {
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadStatusInternalServerError() {
|
||||
s.status = http.StatusInternalServerError
|
||||
s.data = []byte("Internal Server Error")
|
||||
s.header.Set("Content-Type", "text/plain")
|
||||
s.testServer().
|
||||
SetStatusCode(http.StatusInternalServerError).
|
||||
SetBody([]byte("Internal Server Error")).
|
||||
SetHeaders(httpheaders.ContentType, "text/plain")
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
|
||||
@@ -221,7 +204,7 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
|
||||
|
||||
serverURL := fmt.Sprintf("http://%s", l.Addr().String())
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), serverURL, "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(500, ierrors.Wrap(err, 0).StatusCode())
|
||||
@@ -229,19 +212,19 @@ func (s *ImageDataTestSuite) TestDownloadUnreachable() {
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadInvalidImage() {
|
||||
s.data = []byte("invalid")
|
||||
s.testServer().SetBody([]byte("invalid"))
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
|
||||
s.Require().Equal(http.StatusUnprocessableEntity, ierrors.Wrap(err, 0).StatusCode())
|
||||
s.Require().Nil(imgdata)
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
|
||||
config.AllowLoopbackSourceAddresses = false
|
||||
s.fetcherCfg().Transport.HTTP.AllowLoopbackSourceAddresses = false
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(404, ierrors.Wrap(err, 0).StatusCode())
|
||||
@@ -249,11 +232,10 @@ func (s *ImageDataTestSuite) TestDownloadSourceAddressNotAllowed() {
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestDownloadImageFileTooLarge() {
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{
|
||||
MaxSrcFileSize: 1,
|
||||
})
|
||||
|
||||
fmt.Println(err)
|
||||
s.Require().Error(err)
|
||||
s.Require().Equal(422, ierrors.Wrap(err, 0).StatusCode())
|
||||
s.Require().Nil(imgdata)
|
||||
@@ -263,39 +245,43 @@ func (s *ImageDataTestSuite) TestDownloadGzip() {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
enc := gzip.NewWriter(buf)
|
||||
_, err := enc.Write(s.defaultData)
|
||||
_, err := enc.Write(s.data)
|
||||
s.Require().NoError(err)
|
||||
err = enc.Close()
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.data = buf.Bytes()
|
||||
s.header.Set("Content-Encoding", "gzip")
|
||||
s.testServer().
|
||||
SetBody(buf.Bytes()).
|
||||
SetHeaders(
|
||||
httpheaders.ContentEncoding, "gzip",
|
||||
httpheaders.ContentLength, strconv.Itoa(buf.Len()), // Update Content-Length
|
||||
)
|
||||
|
||||
imgdata, _, err := s.factory.DownloadSync(context.Background(), s.server.URL, "Test image", DownloadOptions{})
|
||||
imgdata, _, err := s.factory().DownloadSync(context.Background(), s.testServer().URL(), "Test image", DownloadOptions{})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(imgdata)
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
|
||||
s.Require().Equal(imagetype.JPEG, imgdata.Format())
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestFromFile() {
|
||||
imgdata, err := s.factory.NewFromPath("../testdata/test1.jpg")
|
||||
imgdata, err := s.factory().NewFromPath("../testdata/test1.jpg")
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(imgdata)
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
|
||||
s.Require().Equal(imagetype.JPEG, imgdata.Format())
|
||||
}
|
||||
|
||||
func (s *ImageDataTestSuite) TestFromBase64() {
|
||||
b64 := base64.StdEncoding.EncodeToString(s.defaultData)
|
||||
b64 := base64.StdEncoding.EncodeToString(s.data)
|
||||
|
||||
imgdata, err := s.factory.NewFromBase64(b64)
|
||||
imgdata, err := s.factory().NewFromBase64(b64)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(imgdata)
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.defaultData), imgdata.Reader()))
|
||||
s.Require().True(testutil.ReadersEqual(s.T(), bytes.NewReader(s.data), imgdata.Reader()))
|
||||
s.Require().Equal(imagetype.JPEG, imgdata.Format())
|
||||
}
|
||||
|
||||
|
@@ -27,10 +27,7 @@ type ProcessingHandlerTestSuite struct {
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) SetupTest() {
|
||||
config.Reset() // We reset config only at the start of each test
|
||||
|
||||
// NOTE: This must be moved to security config
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
// NOTE: end note
|
||||
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) SetupSubTest() {
|
||||
@@ -142,13 +139,13 @@ func (s *ProcessingHandlerTestSuite) TestSourceNetworkValidation() {
|
||||
|
||||
// We wrap this in a subtest to reset s.router()
|
||||
s.Run("AllowLoopbackSourceAddressesTrue", func() {
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = true
|
||||
res := s.GET(url)
|
||||
s.Require().Equal(http.StatusOK, res.StatusCode)
|
||||
})
|
||||
|
||||
s.Run("AllowLoopbackSourceAddressesFalse", func() {
|
||||
config.AllowLoopbackSourceAddresses = false
|
||||
s.Config().Fetcher.Transport.HTTP.AllowLoopbackSourceAddresses = false
|
||||
res := s.GET(url)
|
||||
s.Require().Equal(http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
@@ -256,7 +253,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughExpires() {
|
||||
config.CacheControlPassthrough = true
|
||||
s.Config().Server.ResponseWriter.CacheControlPassthrough = true
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set(httpheaders.Expires, time.Now().Add(1239*time.Second).UTC().Format(http.TimeFormat))
|
||||
@@ -290,7 +287,7 @@ func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughDisabled() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
|
||||
config.ETagEnabled = false
|
||||
s.Config().Handlers.Processing.ETagEnabled = false
|
||||
|
||||
res := s.GET("/unsafe/rs:fill:4:4/plain/local:///test1.png")
|
||||
|
||||
@@ -299,7 +296,7 @@ func (s *ProcessingHandlerTestSuite) TestETagDisabled() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
|
||||
config.ETagEnabled = true
|
||||
s.Config().Handlers.Processing.ETagEnabled = true
|
||||
|
||||
etag := `"loremipsumdolor"`
|
||||
|
||||
@@ -321,7 +318,8 @@ func (s *ProcessingHandlerTestSuite) TestETagDataMatch() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
|
||||
config.LastModifiedEnabled = true
|
||||
s.Config().Handlers.Processing.LastModifiedEnabled = true
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
|
||||
rw.WriteHeader(200)
|
||||
@@ -335,7 +333,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedEnabled() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
|
||||
config.LastModifiedEnabled = false
|
||||
s.Config().Handlers.Processing.LastModifiedEnabled = false
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set(httpheaders.LastModified, "Wed, 21 Oct 2015 07:28:00 GMT")
|
||||
rw.WriteHeader(200)
|
||||
@@ -349,7 +347,7 @@ func (s *ProcessingHandlerTestSuite) TestLastModifiedDisabled() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedDisabled() {
|
||||
config.LastModifiedEnabled = false
|
||||
s.Config().Handlers.Processing.LastModifiedEnabled = false
|
||||
data := s.TestData.Read("test1.png")
|
||||
lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
@@ -368,7 +366,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedD
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedEnabled() {
|
||||
config.LastModifiedEnabled = true
|
||||
s.Config().Handlers.Processing.LastModifiedEnabled = true
|
||||
lastModified := "Wed, 21 Oct 2015 07:28:00 GMT"
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
|
||||
@@ -386,7 +384,7 @@ func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqExactMatchLastModifiedE
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestModifiedSinceReqCompareMoreRecentLastModifiedDisabled() {
|
||||
data := s.TestData.Read("test1.png")
|
||||
config.LastModifiedEnabled = false
|
||||
s.Config().Handlers.Processing.LastModifiedEnabled = false
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
modifiedSince := r.Header.Get(httpheaders.IfModifiedSince)
|
||||
s.Empty(modifiedSince)
|
||||
|
@@ -1082,10 +1082,10 @@ func applyURLOption(po *ProcessingOptions, name string, args []string, usedPrese
|
||||
}
|
||||
|
||||
func applyURLOptions(po *ProcessingOptions, options urlOptions, allowAll bool, usedPresets ...string) error {
|
||||
allowAll = allowAll || len(config.AllowedProcessiongOptions) == 0
|
||||
allowAll = allowAll || len(config.AllowedProcessingOptions) == 0
|
||||
|
||||
for _, opt := range options {
|
||||
if !allowAll && !slices.Contains(config.AllowedProcessiongOptions, opt.Name) {
|
||||
if !allowAll && !slices.Contains(config.AllowedProcessingOptions, opt.Name) {
|
||||
return newForbiddenOptionError("processing", opt.Name)
|
||||
}
|
||||
|
||||
|
@@ -646,7 +646,7 @@ func (s *ProcessingOptionsTestSuite) TestParseBase64URLOnlyPresets() {
|
||||
}
|
||||
|
||||
func (s *ProcessingOptionsTestSuite) TestParseAllowedOptions() {
|
||||
config.AllowedProcessiongOptions = []string{"w", "h", "pr"}
|
||||
config.AllowedProcessingOptions = []string{"w", "h", "pr"}
|
||||
|
||||
presets["test1"] = urlOptions{
|
||||
urlOption{Name: "blur", Args: []string{"0.2"}},
|
||||
|
@@ -13,7 +13,6 @@ type (
|
||||
ImageResolutionError string
|
||||
SecurityOptionsError struct{}
|
||||
SourceURLError string
|
||||
SourceAddressError string
|
||||
)
|
||||
|
||||
func newSignatureError(msg string) error {
|
||||
@@ -75,15 +74,3 @@ func newSourceURLError(imageURL string) error {
|
||||
}
|
||||
|
||||
func (e SourceURLError) Error() string { return string(e) }
|
||||
|
||||
func newSourceAddressError(msg string) error {
|
||||
return ierrors.Wrap(
|
||||
SourceAddressError(msg),
|
||||
1,
|
||||
ierrors.WithStatusCode(http.StatusNotFound),
|
||||
ierrors.WithPublicMessage("Invalid source URL"),
|
||||
ierrors.WithShouldReport(false),
|
||||
)
|
||||
}
|
||||
|
||||
func (e SourceAddressError) Error() string { return string(e) }
|
||||
|
@@ -1,9 +1,6 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
)
|
||||
|
||||
@@ -20,29 +17,3 @@ func VerifySourceURL(imageURL string) error {
|
||||
|
||||
return newSourceURLError(imageURL)
|
||||
}
|
||||
|
||||
func VerifySourceNetwork(addr string) error {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return newSourceAddressError(fmt.Sprintf("Invalid source address: %s", addr))
|
||||
}
|
||||
|
||||
if !config.AllowLoopbackSourceAddresses && (ip.IsLoopback() || ip.IsUnspecified()) {
|
||||
return newSourceAddressError(fmt.Sprintf("Loopback source address is not allowed: %s", addr))
|
||||
}
|
||||
|
||||
if !config.AllowLinkLocalSourceAddresses && (ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) {
|
||||
return newSourceAddressError(fmt.Sprintf("Link-local source address is not allowed: %s", addr))
|
||||
}
|
||||
|
||||
if !config.AllowPrivateSourceAddresses && ip.IsPrivate() {
|
||||
return newSourceAddressError(fmt.Sprintf("Private source address is not allowed: %s", addr))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -51,7 +51,7 @@ func NewLazySuiteObj[T any](
|
||||
// Get the [LazySuite] instance
|
||||
lazy := s.Lazy()
|
||||
// Create the [LazyObj] instance
|
||||
obj, cancel := NewLazyObj(lazy, newFn, dropFn...)
|
||||
obj, cancel := newLazyObj(lazy, newFn, dropFn...)
|
||||
// Add cleanup function to the resets list
|
||||
lazy.resets = append(lazy.resets, cancel)
|
||||
|
||||
|
@@ -23,10 +23,10 @@ type LazyObjNew[T any] func() (T, error)
|
||||
// If the object was not yet initialized, the callback is not called.
|
||||
type LazyObjDrop[T any] func(T) error
|
||||
|
||||
// NewLazyObj creates a new [LazyObj] that initializes the object on the first call.
|
||||
// newLazyObj creates a new [LazyObj] that initializes the object on the first call.
|
||||
// It returns a function that can be called to get the object and a cancel function
|
||||
// that can be called to reset the object.
|
||||
func NewLazyObj[T any](
|
||||
func newLazyObj[T any](
|
||||
s LazyObjT,
|
||||
newFn LazyObjNew[T],
|
||||
dropFn ...LazyObjDrop[T],
|
||||
|
@@ -26,9 +26,6 @@ type TestDataProvider struct {
|
||||
|
||||
// New creates a new TestDataProvider
|
||||
func NewTestDataProvider(t TestDataProviderT) *TestDataProvider {
|
||||
// if h, ok := t.(interface{ Helper() }); ok {
|
||||
// h.Helper()
|
||||
// }
|
||||
t().Helper()
|
||||
|
||||
path, err := findProjectRoot()
|
||||
|
122
testutil/test_server.go
Normal file
122
testutil/test_server.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestServerHookFunc is a function type for in-request hooks
|
||||
type TestServerHookFunc func(r *http.Request, rw http.ResponseWriter)
|
||||
|
||||
// Sugar alias
|
||||
type LazyTestServer = LazyObj[*TestServer]
|
||||
|
||||
// TestServer is a syntax sugar wrapper over httptest.Server
|
||||
type TestServer struct {
|
||||
testServer *httptest.Server
|
||||
status int
|
||||
data []byte
|
||||
header http.Header
|
||||
hook TestServerHookFunc
|
||||
}
|
||||
|
||||
// NewLazySuiteTestServer creates a lazy TestServer object for use in test suites
|
||||
func NewLazySuiteTestServer(
|
||||
l LazySuiteFrom,
|
||||
init ...func(*TestServer) error,
|
||||
) (LazyObj[*TestServer], context.CancelFunc) {
|
||||
return NewLazySuiteObj(
|
||||
l,
|
||||
func() (*TestServer, error) {
|
||||
s := NewTestServer()
|
||||
|
||||
if len(init) > 0 {
|
||||
for _, fn := range init {
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err := fn(s)
|
||||
require.NoError(l.Lazy().T(), err, "Failed to reset test server")
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
},
|
||||
func(s *TestServer) error {
|
||||
s.Close()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// New creates and starts new http.TestServer
|
||||
func NewTestServer() *TestServer {
|
||||
ts := &TestServer{
|
||||
status: http.StatusOK,
|
||||
header: make(http.Header),
|
||||
data: nil,
|
||||
hook: nil,
|
||||
}
|
||||
|
||||
return ts.start()
|
||||
}
|
||||
|
||||
// SetStatusCode sets the status code that will be returned by the server
|
||||
func (s *TestServer) SetStatusCode(status int) *TestServer {
|
||||
s.status = status
|
||||
return s
|
||||
}
|
||||
|
||||
// SetBody sets the body that will be returned by the server
|
||||
func (s *TestServer) SetBody(data []byte) *TestServer {
|
||||
s.data = data
|
||||
return s
|
||||
}
|
||||
|
||||
// WithHeader adds headers that will be returned by the server.
|
||||
// Odd arguments are treated as keys, even arguments as values.
|
||||
func (s *TestServer) SetHeaders(kv ...string) *TestServer {
|
||||
for i := 0; i+1 < len(kv); i += 2 {
|
||||
key := kv[i]
|
||||
value := kv[i+1]
|
||||
s.header.Set(key, value)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetHook sets a function that will be called on each request. It is called
|
||||
// after headsers are set, but before status and body are written.
|
||||
func (s *TestServer) SetHook(f TestServerHookFunc) *TestServer {
|
||||
s.hook = f
|
||||
return s
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *TestServer) start() *TestServer {
|
||||
s.testServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpheaders.CopyAll(s.header, w.Header(), true)
|
||||
if s.hook != nil {
|
||||
s.hook(r, w)
|
||||
}
|
||||
w.WriteHeader(s.status)
|
||||
w.Write(s.data)
|
||||
}))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Close stops the server
|
||||
func (s *TestServer) Close() {
|
||||
s.testServer.Close()
|
||||
}
|
||||
|
||||
// URL returns the server URL
|
||||
func (s *TestServer) URL() string {
|
||||
return s.testServer.URL
|
||||
}
|
Reference in New Issue
Block a user