mirror of
https://github.com/imgproxy/imgproxy.git
synced 2025-09-28 20:43:54 +02:00
Rebuild headerwriter to server.ResponseWriter
This commit is contained in:
committed by
Sergei Aleksandrovich
parent
53645688fb
commit
1f6d007948
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
|
||||
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
|
||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
||||
"github.com/imgproxy/imgproxy/v3/semaphores"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
)
|
||||
@@ -19,7 +18,6 @@ type HandlerConfigs struct {
|
||||
|
||||
// Config represents an instance configuration
|
||||
type Config struct {
|
||||
HeaderWriter headerwriter.Config
|
||||
Semaphores semaphores.Config
|
||||
FallbackImage auximageprovider.StaticConfig
|
||||
WatermarkImage auximageprovider.StaticConfig
|
||||
@@ -31,7 +29,6 @@ type Config struct {
|
||||
// NewDefaultConfig creates a new default configuration
|
||||
func NewDefaultConfig() Config {
|
||||
return Config{
|
||||
HeaderWriter: headerwriter.NewDefaultConfig(),
|
||||
Semaphores: semaphores.NewDefaultConfig(),
|
||||
FallbackImage: auximageprovider.NewDefaultStaticConfig(),
|
||||
WatermarkImage: auximageprovider.NewDefaultStaticConfig(),
|
||||
@@ -62,10 +59,6 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = headerwriter.LoadConfigFromEnv(&c.HeaderWriter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = semaphores.LoadConfigFromEnv(&c.Semaphores); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -22,7 +22,7 @@ func New() *Handler {
|
||||
// Execute handles the health request
|
||||
func (h *Handler) Execute(
|
||||
reqID string,
|
||||
rw http.ResponseWriter,
|
||||
rw server.ResponseWriter,
|
||||
req *http.Request,
|
||||
) error {
|
||||
var (
|
||||
|
@@ -6,11 +6,18 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||
)
|
||||
|
||||
func TestHealthHandler(t *testing.T) {
|
||||
// Create responsewriter.Factory
|
||||
rwConf := responsewriter.NewDefaultConfig()
|
||||
rwf, err := responsewriter.NewFactory(&rwConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a ResponseRecorder to record the response
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
@@ -18,7 +25,7 @@ func TestHealthHandler(t *testing.T) {
|
||||
h := New()
|
||||
|
||||
// Call the handler function directly (no need for actual HTTP request)
|
||||
h.Execute("test-req-id", rr, nil)
|
||||
h.Execute("test-req-id", rwf.NewWriter(rr), nil)
|
||||
|
||||
// Check that we get a valid response (either 200 or 500 depending on vips state)
|
||||
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
)
|
||||
|
||||
//go:embed body.html
|
||||
@@ -21,7 +22,7 @@ func New() *Handler {
|
||||
// Execute handles the landing request
|
||||
func (h *Handler) Execute(
|
||||
reqID string,
|
||||
rw http.ResponseWriter,
|
||||
rw server.ResponseWriter,
|
||||
req *http.Request,
|
||||
) error {
|
||||
rw.Header().Set(httpheaders.ContentType, "text/html")
|
||||
|
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||
"github.com/imgproxy/imgproxy/v3/handlers"
|
||||
"github.com/imgproxy/imgproxy/v3/handlers/stream"
|
||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/monitoring"
|
||||
@@ -17,11 +16,11 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/security"
|
||||
"github.com/imgproxy/imgproxy/v3/semaphores"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
)
|
||||
|
||||
// HandlerContext provides access to shared handler dependencies
|
||||
type HandlerContext interface {
|
||||
HeaderWriter() *headerwriter.Writer
|
||||
Semaphores() *semaphores.Semaphores
|
||||
FallbackImage() auximageprovider.Provider
|
||||
WatermarkImage() auximageprovider.Provider
|
||||
@@ -56,7 +55,7 @@ func New(
|
||||
// Execute handles the image processing request
|
||||
func (h *Handler) Execute(
|
||||
reqID string,
|
||||
rw http.ResponseWriter,
|
||||
rw server.ResponseWriter,
|
||||
req *http.Request,
|
||||
) error {
|
||||
// Increment the number of requests in progress
|
||||
@@ -86,7 +85,6 @@ func (h *Handler) Execute(
|
||||
po: po,
|
||||
imageURL: imageURL,
|
||||
monitoringMeta: mm,
|
||||
hwr: h.HeaderWriter().NewRequest(),
|
||||
}
|
||||
|
||||
return hReq.execute(ctx)
|
||||
|
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||
"github.com/imgproxy/imgproxy/v3/handlers"
|
||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||
"github.com/imgproxy/imgproxy/v3/monitoring"
|
||||
@@ -23,12 +22,11 @@ type request struct {
|
||||
|
||||
reqID string
|
||||
req *http.Request
|
||||
rw http.ResponseWriter
|
||||
rw server.ResponseWriter
|
||||
config *Config
|
||||
po *options.ProcessingOptions
|
||||
imageURL string
|
||||
monitoringMeta monitoring.Meta
|
||||
hwr *headerwriter.Request
|
||||
}
|
||||
|
||||
// execute handles the actual processing logic
|
||||
@@ -84,13 +82,13 @@ func (r *request) execute(ctx context.Context) error {
|
||||
var nmErr fetcher.NotModifiedError
|
||||
|
||||
if errors.As(err, &nmErr) {
|
||||
r.hwr.SetOriginHeaders(nmErr.Headers())
|
||||
r.rw.SetOriginHeaders(nmErr.Headers())
|
||||
|
||||
return r.respondWithNotModified()
|
||||
}
|
||||
|
||||
// Prepare to write image response headers
|
||||
r.hwr.SetOriginHeaders(originHeaders)
|
||||
r.rw.SetOriginHeaders(originHeaders)
|
||||
|
||||
// If error is not related to NotModified, respond with fallback image and replace image data
|
||||
if err != nil {
|
||||
@@ -123,7 +121,7 @@ func (r *request) execute(ctx context.Context) error {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryProcessing))
|
||||
}
|
||||
|
||||
// Write debug headers. It seems unlogical to move they to headerwriter since they're
|
||||
// Write debug headers. It seems unlogical to move they to responsewriter since they're
|
||||
// not used anywhere else.
|
||||
err = r.writeDebugHeaders(result, originData)
|
||||
if err != nil {
|
||||
|
@@ -123,8 +123,8 @@ func (r *request) handleDownloadError(
|
||||
headers.Del(httpheaders.Expires)
|
||||
headers.Del(httpheaders.LastModified)
|
||||
|
||||
r.hwr.SetOriginHeaders(headers)
|
||||
r.hwr.SetIsFallbackImage()
|
||||
r.rw.SetOriginHeaders(headers)
|
||||
r.rw.SetIsFallbackImage()
|
||||
|
||||
return data, statusCode, nil
|
||||
}
|
||||
@@ -186,19 +186,17 @@ func (r *request) writeDebugHeaders(result *processing.Result, originData imaged
|
||||
|
||||
// respondWithNotModified writes not-modified response
|
||||
func (r *request) respondWithNotModified() error {
|
||||
r.hwr.SetExpires(r.po.Expires)
|
||||
r.hwr.SetVary()
|
||||
r.rw.SetExpires(r.po.Expires)
|
||||
r.rw.SetVary()
|
||||
|
||||
if r.config.LastModifiedEnabled {
|
||||
r.hwr.Passthrough(httpheaders.LastModified)
|
||||
r.rw.Passthrough(httpheaders.LastModified)
|
||||
}
|
||||
|
||||
if r.config.ETagEnabled {
|
||||
r.hwr.Passthrough(httpheaders.Etag)
|
||||
r.rw.Passthrough(httpheaders.Etag)
|
||||
}
|
||||
|
||||
r.hwr.Write(r.rw)
|
||||
|
||||
r.rw.WriteHeader(http.StatusNotModified)
|
||||
|
||||
server.LogResponse(
|
||||
@@ -221,29 +219,27 @@ func (r *request) respondWithImage(statusCode int, resultData imagedata.ImageDat
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryImageDataSize))
|
||||
}
|
||||
|
||||
r.hwr.SetContentType(resultData.Format().Mime())
|
||||
r.hwr.SetContentLength(resultSize)
|
||||
r.hwr.SetContentDisposition(
|
||||
r.rw.SetContentType(resultData.Format().Mime())
|
||||
r.rw.SetContentLength(resultSize)
|
||||
r.rw.SetContentDisposition(
|
||||
r.imageURL,
|
||||
r.po.Filename,
|
||||
resultData.Format().Ext(),
|
||||
"",
|
||||
r.po.ReturnAttachment,
|
||||
)
|
||||
r.hwr.SetExpires(r.po.Expires)
|
||||
r.hwr.SetVary()
|
||||
r.hwr.SetCanonical(r.imageURL)
|
||||
r.rw.SetExpires(r.po.Expires)
|
||||
r.rw.SetVary()
|
||||
r.rw.SetCanonical(r.imageURL)
|
||||
|
||||
if r.config.LastModifiedEnabled {
|
||||
r.hwr.Passthrough(httpheaders.LastModified)
|
||||
r.rw.Passthrough(httpheaders.LastModified)
|
||||
}
|
||||
|
||||
if r.config.ETagEnabled {
|
||||
r.hwr.Passthrough(httpheaders.Etag)
|
||||
r.rw.Passthrough(httpheaders.Etag)
|
||||
}
|
||||
|
||||
r.hwr.Write(r.rw)
|
||||
|
||||
r.rw.WriteHeader(statusCode)
|
||||
|
||||
_, err = io.Copy(r.rw, resultData.Reader())
|
||||
|
@@ -6,16 +6,16 @@ import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/cookies"
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/monitoring"
|
||||
"github.com/imgproxy/imgproxy/v3/monitoring/stats"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,7 +37,6 @@ var (
|
||||
type Handler struct {
|
||||
config *Config // Configuration for the streamer
|
||||
fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching
|
||||
hw *headerwriter.Writer // Configured HeaderWriter instance
|
||||
}
|
||||
|
||||
// request holds the parameters and state for a single streaming request
|
||||
@@ -47,12 +46,11 @@ type request struct {
|
||||
imageURL string
|
||||
reqID string
|
||||
po *options.ProcessingOptions
|
||||
rw http.ResponseWriter
|
||||
hw *headerwriter.Request
|
||||
rw server.ResponseWriter
|
||||
}
|
||||
|
||||
// New creates new handler object
|
||||
func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Handler, error) {
|
||||
func New(config *Config, fetcher *fetcher.Fetcher) (*Handler, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -60,7 +58,6 @@ func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Ha
|
||||
return &Handler{
|
||||
fetcher: fetcher,
|
||||
config: config,
|
||||
hw: hw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -71,7 +68,7 @@ func (s *Handler) Execute(
|
||||
imageURL string,
|
||||
reqID string,
|
||||
po *options.ProcessingOptions,
|
||||
rw http.ResponseWriter,
|
||||
rw server.ResponseWriter,
|
||||
) error {
|
||||
stream := &request{
|
||||
handler: s,
|
||||
@@ -80,7 +77,6 @@ func (s *Handler) Execute(
|
||||
reqID: reqID,
|
||||
po: po,
|
||||
rw: rw,
|
||||
hw: s.hw.NewRequest(),
|
||||
}
|
||||
|
||||
return stream.execute(ctx)
|
||||
@@ -118,17 +114,14 @@ func (s *request) execute(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Output streaming response headers
|
||||
s.hw.SetOriginHeaders(res.Header)
|
||||
s.hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
|
||||
s.hw.SetContentLength(int(res.ContentLength))
|
||||
s.hw.SetCanonical(s.imageURL)
|
||||
s.hw.SetExpires(s.po.Expires)
|
||||
s.rw.SetOriginHeaders(res.Header)
|
||||
s.rw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
|
||||
s.rw.SetContentLength(int(res.ContentLength))
|
||||
s.rw.SetCanonical(s.imageURL)
|
||||
s.rw.SetExpires(s.po.Expires)
|
||||
|
||||
// Set the Content-Disposition header
|
||||
s.setContentDisposition(r.URL().Path, res, s.hw)
|
||||
|
||||
// Write headers from writer
|
||||
s.hw.Write(s.rw)
|
||||
s.setContentDisposition(r.URL().Path, res)
|
||||
|
||||
// Copy the status code from the original response
|
||||
s.rw.WriteHeader(res.StatusCode)
|
||||
@@ -158,7 +151,7 @@ func (s *request) getImageRequestHeaders() http.Header {
|
||||
}
|
||||
|
||||
// setContentDisposition writes the headers to the response writer
|
||||
func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response, hw *headerwriter.Request) {
|
||||
func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response) {
|
||||
// Try to set correct Content-Disposition file name and extension
|
||||
if serverResponse.StatusCode < 200 || serverResponse.StatusCode >= 300 {
|
||||
return
|
||||
@@ -166,7 +159,7 @@ func (s *request) setContentDisposition(imagePath string, serverResponse *http.R
|
||||
|
||||
ct := serverResponse.Header.Get(httpheaders.ContentType)
|
||||
|
||||
hw.SetContentDisposition(
|
||||
s.rw.SetContentDisposition(
|
||||
imagePath,
|
||||
s.po.Filename,
|
||||
"",
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -17,9 +16,10 @@ import (
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||
"github.com/imgproxy/imgproxy/v3/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,14 +27,54 @@ const (
|
||||
)
|
||||
|
||||
type HandlerTestSuite struct {
|
||||
suite.Suite
|
||||
handler *Handler
|
||||
testutil.LazySuite
|
||||
|
||||
rwConf testutil.LazyObj[*responsewriter.Config]
|
||||
rwFactory testutil.LazyObj[*responsewriter.Factory]
|
||||
|
||||
config testutil.LazyObj[*Config]
|
||||
handler testutil.LazyObj[*Handler]
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) SetupSuite() {
|
||||
config.Reset()
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
|
||||
s.rwConf, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
func() (*responsewriter.Config, error) {
|
||||
c := responsewriter.NewDefaultConfig()
|
||||
return &c, nil
|
||||
},
|
||||
)
|
||||
|
||||
s.rwFactory, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
func() (*responsewriter.Factory, error) {
|
||||
return responsewriter.NewFactory(s.rwConf())
|
||||
},
|
||||
)
|
||||
|
||||
s.config, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
func() (*Config, error) {
|
||||
c := NewDefaultConfig()
|
||||
return &c, nil
|
||||
},
|
||||
)
|
||||
|
||||
s.handler, _ = testutil.NewLazySuiteObj(
|
||||
s,
|
||||
func() (*Handler, error) {
|
||||
fc := fetcher.NewDefaultConfig()
|
||||
|
||||
fetcher, err := fetcher.New(&fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
return New(s.config(), fetcher)
|
||||
},
|
||||
)
|
||||
|
||||
// Silence logs during tests
|
||||
logrus.SetOutput(io.Discard)
|
||||
}
|
||||
@@ -47,21 +87,11 @@ func (s *HandlerTestSuite) TearDownSuite() {
|
||||
func (s *HandlerTestSuite) SetupTest() {
|
||||
config.Reset()
|
||||
config.AllowLoopbackSourceAddresses = true
|
||||
}
|
||||
|
||||
fc := fetcher.NewDefaultConfig()
|
||||
|
||||
fetcher, err := fetcher.New(&fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
cfg := NewDefaultConfig()
|
||||
|
||||
hwc := headerwriter.NewDefaultConfig()
|
||||
hw, err := headerwriter.New(&hwc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
h, err := New(&cfg, hw, fetcher)
|
||||
s.Require().NoError(err)
|
||||
s.handler = h
|
||||
func (s *HandlerTestSuite) SetupSubTest() {
|
||||
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
|
||||
s.ResetLazyObjects()
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) readTestFile(name string) []byte {
|
||||
@@ -70,6 +100,24 @@ func (s *HandlerTestSuite) readTestFile(name string) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *HandlerTestSuite) execute(
|
||||
imageURL string,
|
||||
header http.Header,
|
||||
po *options.ProcessingOptions,
|
||||
) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
httpheaders.CopyAll(header, req.Header, true)
|
||||
|
||||
ctx := s.T().Context()
|
||||
rw := httptest.NewRecorder()
|
||||
rww := s.rwFactory().NewWriter(rw)
|
||||
|
||||
err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
|
||||
s.Require().NoError(err)
|
||||
|
||||
return rw
|
||||
}
|
||||
|
||||
// TestHandlerBasicRequest checks basic streaming request
|
||||
func (s *HandlerTestSuite) TestHandlerBasicRequest() {
|
||||
data := s.readTestFile("test1.png")
|
||||
@@ -81,12 +129,7 @@ func (s *HandlerTestSuite) TestHandlerBasicRequest() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
@@ -114,12 +157,7 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
@@ -148,16 +186,12 @@ func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(httpheaders.IfNoneMatch, etag)
|
||||
req.Header.Set(httpheaders.AcceptEncoding, "gzip")
|
||||
req.Header.Set(httpheaders.Range, "bytes=*")
|
||||
h := make(http.Header)
|
||||
h.Set(httpheaders.IfNoneMatch, etag)
|
||||
h.Set(httpheaders.AcceptEncoding, "gzip")
|
||||
h.Set(httpheaders.Range, "bytes=*")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
@@ -175,8 +209,6 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{
|
||||
Filename: "custom_name",
|
||||
ReturnAttachment: true,
|
||||
@@ -184,8 +216,7 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
|
||||
|
||||
// Use a URL with a .png extension to help content disposition logic
|
||||
imageURL := ts.URL + "/test.png"
|
||||
err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(imageURL, nil, po)
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
@@ -342,25 +373,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
fc, err := fetcher.LoadConfigFromEnv(nil)
|
||||
s.Require().NoError(err)
|
||||
s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
|
||||
s.rwConf().DefaultTTL = 4242
|
||||
|
||||
fetcher, err := fetcher.New(fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
cfg := NewDefaultConfig()
|
||||
hwc := headerwriter.NewDefaultConfig()
|
||||
hwc.CacheControlPassthrough = tc.cacheControlPassthrough
|
||||
hwc.DefaultTTL = 4242
|
||||
|
||||
hw, err := headerwriter.New(&hwc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
handler, err := New(&cfg, hw, fetcher)
|
||||
s.Require().NoError(err)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
if tc.timestampOffset != nil {
|
||||
@@ -368,8 +383,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
||||
po.Expires = &expires
|
||||
}
|
||||
|
||||
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, nil, po)
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
|
||||
@@ -400,12 +414,7 @@ func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
@@ -420,12 +429,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(404, res.StatusCode)
|
||||
@@ -433,21 +437,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
|
||||
|
||||
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
|
||||
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
|
||||
fc, err := fetcher.LoadConfigFromEnv(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
fetcher, err := fetcher.New(fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
cfg := NewDefaultConfig()
|
||||
cfg.CookiePassthrough = true
|
||||
|
||||
hwc := headerwriter.NewDefaultConfig()
|
||||
hw, err := headerwriter.New(&hwc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
handler, err := New(&cfg, hw, fetcher)
|
||||
s.Require().NoError(err)
|
||||
s.config().CookiePassthrough = true
|
||||
|
||||
data := s.readTestFile("test1.png")
|
||||
|
||||
@@ -464,13 +454,10 @@ func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
h := make(http.Header)
|
||||
h.Set(httpheaders.Cookie, "test_cookie=test_value")
|
||||
|
||||
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
@@ -488,29 +475,9 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
|
||||
defer ts.Close()
|
||||
|
||||
for _, sc := range []bool{true, false} {
|
||||
fc, err := fetcher.LoadConfigFromEnv(nil)
|
||||
s.Require().NoError(err)
|
||||
s.rwConf().SetCanonicalHeader = sc
|
||||
|
||||
fetcher, err := fetcher.New(fc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
cfg := NewDefaultConfig()
|
||||
hwc := headerwriter.NewDefaultConfig()
|
||||
|
||||
hwc.SetCanonicalHeader = sc
|
||||
|
||||
hw, err := headerwriter.New(&hwc)
|
||||
s.Require().NoError(err)
|
||||
|
||||
handler, err := New(&cfg, hw, fetcher)
|
||||
s.Require().NoError(err)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
po := &options.ProcessingOptions{}
|
||||
|
||||
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
||||
s.Require().NoError(err)
|
||||
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||
|
||||
res := rw.Result()
|
||||
s.Require().Equal(200, res.StatusCode)
|
||||
|
@@ -1,62 +0,0 @@
|
||||
package headerwriter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/ensure"
|
||||
)
|
||||
|
||||
// Config is the package-local configuration
|
||||
type Config struct {
|
||||
SetCanonicalHeader bool // Indicates whether to set the canonical header
|
||||
DefaultTTL int // Default Cache-Control max-age= value for cached images
|
||||
FallbackImageTTL int // TTL for images served as fallbacks
|
||||
CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
|
||||
EnableClientHints bool // Enable Vary header
|
||||
SetVaryAccept bool // Whether to include Accept in Vary header
|
||||
}
|
||||
|
||||
// NewDefaultConfig returns a new Config instance with default values.
|
||||
func NewDefaultConfig() Config {
|
||||
return Config{
|
||||
SetCanonicalHeader: false,
|
||||
DefaultTTL: 31536000,
|
||||
FallbackImageTTL: 0,
|
||||
CacheControlPassthrough: false,
|
||||
EnableClientHints: false,
|
||||
SetVaryAccept: false,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadConfigFromEnv overrides configuration variables from environment
|
||||
func LoadConfigFromEnv(c *Config) (*Config, error) {
|
||||
c = ensure.Ensure(c, NewDefaultConfig)
|
||||
|
||||
c.SetCanonicalHeader = config.SetCanonicalHeader
|
||||
c.DefaultTTL = config.TTL
|
||||
c.FallbackImageTTL = config.FallbackImageTTL
|
||||
c.CacheControlPassthrough = config.CacheControlPassthrough
|
||||
c.EnableClientHints = config.EnableClientHints
|
||||
c.SetVaryAccept = config.AutoWebp ||
|
||||
config.EnforceWebp ||
|
||||
config.AutoAvif ||
|
||||
config.EnforceAvif ||
|
||||
config.AutoJxl ||
|
||||
config.EnforceJxl
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Validate checks config for errors
|
||||
func (c *Config) Validate() error {
|
||||
if c.DefaultTTL < 0 {
|
||||
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
|
||||
}
|
||||
|
||||
if c.FallbackImageTTL < 0 {
|
||||
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,214 +0,0 @@
|
||||
// headerwriter is responsible for writing processing/stream response headers
|
||||
package headerwriter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
)
|
||||
|
||||
// Writer is a struct that creates header writer factories.
|
||||
type Writer struct {
|
||||
config *Config
|
||||
varyValue string
|
||||
}
|
||||
|
||||
// Request is a private struct that builds HTTP response headers for a specific request.
|
||||
type Request struct {
|
||||
writer *Writer
|
||||
originHeaders http.Header // Original response headers
|
||||
result http.Header // Headers to be written to the response
|
||||
maxAge int // Current max age for Cache-Control header
|
||||
}
|
||||
|
||||
// New creates a new header writer factory with the provided config.
|
||||
func New(config *Config) (*Writer, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vary := make([]string, 0)
|
||||
|
||||
if config.SetVaryAccept {
|
||||
vary = append(vary, "Accept")
|
||||
}
|
||||
|
||||
if config.EnableClientHints {
|
||||
vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
|
||||
}
|
||||
|
||||
varyValue := strings.Join(vary, ", ")
|
||||
|
||||
return &Writer{
|
||||
config: config,
|
||||
varyValue: varyValue,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL.
|
||||
func (w *Writer) NewRequest() *Request {
|
||||
return &Request{
|
||||
writer: w,
|
||||
result: make(http.Header),
|
||||
maxAge: -1,
|
||||
originHeaders: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
// SetOriginHeaders sets the origin headers for the request.
|
||||
func (r *Request) SetOriginHeaders(h http.Header) {
|
||||
r.originHeaders = h
|
||||
}
|
||||
|
||||
// SetIsFallbackImage sets the Fallback-Image header to
|
||||
// indicate that the fallback image was used.
|
||||
func (r *Request) SetIsFallbackImage() {
|
||||
// We set maxAge to FallbackImageTTL if it's explicitly passed
|
||||
if r.writer.config.FallbackImageTTL < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// However, we should not overwrite existing value if set (or greater than ours)
|
||||
if r.maxAge < 0 || r.maxAge > r.writer.config.FallbackImageTTL {
|
||||
r.maxAge = r.writer.config.FallbackImageTTL
|
||||
}
|
||||
}
|
||||
|
||||
// SetExpires sets the TTL from time
|
||||
func (r *Request) SetExpires(expires *time.Time) {
|
||||
if expires == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert current maxAge to time
|
||||
currentMaxAgeTime := time.Now().Add(time.Duration(r.maxAge) * time.Second)
|
||||
|
||||
// If maxAge outlives expires or was not set, we'll use expires as maxAge.
|
||||
if r.maxAge < 0 || expires.Before(currentMaxAgeTime) {
|
||||
r.maxAge = min(r.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
|
||||
}
|
||||
}
|
||||
|
||||
// SetVary sets the Vary header
|
||||
func (r *Request) SetVary() {
|
||||
if len(r.writer.varyValue) > 0 {
|
||||
r.result.Set(httpheaders.Vary, r.writer.varyValue)
|
||||
}
|
||||
}
|
||||
|
||||
// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
|
||||
func (r *Request) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
|
||||
value := httpheaders.ContentDispositionValue(
|
||||
originURL,
|
||||
filename,
|
||||
ext,
|
||||
contentType,
|
||||
returnAttachment,
|
||||
)
|
||||
|
||||
if value != "" {
|
||||
r.result.Set(httpheaders.ContentDisposition, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Passthrough copies specified headers from the original response headers to the response headers.
|
||||
func (r *Request) Passthrough(only ...string) {
|
||||
httpheaders.Copy(r.originHeaders, r.result, only)
|
||||
}
|
||||
|
||||
// CopyFrom copies specified headers from the headers object. Please note that
|
||||
// all the past operations may overwrite those values.
|
||||
func (r *Request) CopyFrom(headers http.Header, only []string) {
|
||||
httpheaders.Copy(headers, r.result, only)
|
||||
}
|
||||
|
||||
// SetContentLength sets the Content-Length header
|
||||
func (r *Request) SetContentLength(contentLength int) {
|
||||
if contentLength < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
|
||||
}
|
||||
|
||||
// SetContentType sets the Content-Type header
|
||||
func (r *Request) SetContentType(mime string) {
|
||||
r.result.Set(httpheaders.ContentType, mime)
|
||||
}
|
||||
|
||||
// writeCanonical sets the Link header with the canonical URL.
|
||||
// It is mandatory for any response if enabled in the configuration.
|
||||
func (r *Request) SetCanonical(url string) {
|
||||
if !r.writer.config.SetCanonicalHeader {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
|
||||
value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
|
||||
r.result.Set(httpheaders.Link, value)
|
||||
}
|
||||
}
|
||||
|
||||
// setCacheControl sets the Cache-Control header with the specified value.
|
||||
func (r *Request) setCacheControl(value int) bool {
|
||||
if value <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
r.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
|
||||
return true
|
||||
}
|
||||
|
||||
// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
|
||||
func (r *Request) setCacheControlNoCache() {
|
||||
r.result.Set(httpheaders.CacheControl, "no-cache")
|
||||
}
|
||||
|
||||
// setCacheControlPassthrough sets the Cache-Control header from the request
|
||||
// if passthrough is enabled in the configuration.
|
||||
func (r *Request) setCacheControlPassthrough() bool {
|
||||
if !r.writer.config.CacheControlPassthrough || r.maxAge > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if val := r.originHeaders.Get(httpheaders.CacheControl); val != "" {
|
||||
r.result.Set(httpheaders.CacheControl, val)
|
||||
return true
|
||||
}
|
||||
|
||||
if val := r.originHeaders.Get(httpheaders.Expires); val != "" {
|
||||
if t, err := time.Parse(http.TimeFormat, val); err == nil {
|
||||
maxAge := max(0, int(time.Until(t).Seconds()))
|
||||
return r.setCacheControl(maxAge)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// setCSP sets the Content-Security-Policy header to prevent script execution.
|
||||
func (r *Request) setCSP() {
|
||||
r.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
|
||||
}
|
||||
|
||||
// Write writes the headers to the response writer. It does not overwrite
|
||||
// target headers, which were set outside the header writer.
|
||||
func (r *Request) Write(rw http.ResponseWriter) {
|
||||
// Then, let's try to set Cache-Control using priority order
|
||||
switch {
|
||||
case r.setCacheControl(r.maxAge): // First, try set explicit
|
||||
case r.setCacheControlPassthrough(): // Try to pick up from request headers
|
||||
case r.setCacheControl(r.writer.config.DefaultTTL): // Fallback to default value
|
||||
default:
|
||||
r.setCacheControlNoCache() // By default we use no-cache
|
||||
}
|
||||
|
||||
r.setCSP()
|
||||
|
||||
// Copy all headers to the response without overwriting existing ones
|
||||
httpheaders.CopyAll(r.result, rw.Header(), false)
|
||||
}
|
14
imgproxy.go
14
imgproxy.go
@@ -11,7 +11,6 @@ import (
|
||||
landinghandler "github.com/imgproxy/imgproxy/v3/handlers/landing"
|
||||
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
|
||||
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
|
||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/memory"
|
||||
"github.com/imgproxy/imgproxy/v3/monitoring/prometheus"
|
||||
@@ -34,7 +33,6 @@ type ImgproxyHandlers struct {
|
||||
|
||||
// Imgproxy holds all the components needed for imgproxy to function
|
||||
type Imgproxy struct {
|
||||
headerWriter *headerwriter.Writer
|
||||
semaphores *semaphores.Semaphores
|
||||
fallbackImage auximageprovider.Provider
|
||||
watermarkImage auximageprovider.Provider
|
||||
@@ -46,11 +44,6 @@ type Imgproxy struct {
|
||||
|
||||
// New creates a new imgproxy instance
|
||||
func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
||||
headerWriter, err := headerwriter.New(&config.HeaderWriter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fetcher, err := fetcher.New(&config.Fetcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -74,7 +67,6 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
||||
}
|
||||
|
||||
imgproxy := &Imgproxy{
|
||||
headerWriter: headerWriter,
|
||||
semaphores: semaphores,
|
||||
fallbackImage: fallbackImage,
|
||||
watermarkImage: watermarkImage,
|
||||
@@ -86,7 +78,7 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
||||
imgproxy.handlers.Health = healthhandler.New()
|
||||
imgproxy.handlers.Landing = landinghandler.New()
|
||||
|
||||
imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, headerWriter, fetcher)
|
||||
imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, fetcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -180,10 +172,6 @@ func (i *Imgproxy) startMemoryTicker(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Imgproxy) HeaderWriter() *headerwriter.Writer {
|
||||
return i.headerWriter
|
||||
}
|
||||
|
||||
func (i *Imgproxy) Semaphores() *semaphores.Semaphores {
|
||||
return i.semaphores
|
||||
}
|
||||
|
@@ -238,7 +238,7 @@ func (s *ProcessingHandlerTestSuite) TestErrorSavingToSVG() {
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
|
||||
s.Config().HeaderWriter.CacheControlPassthrough = true
|
||||
s.Config().Server.ResponseWriter.CacheControlPassthrough = true
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set(httpheaders.CacheControl, "max-age=1234, public")
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/ensure"
|
||||
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||
)
|
||||
|
||||
// Config represents HTTP server config
|
||||
@@ -18,7 +19,6 @@ type Config struct {
|
||||
PathPrefix string // Path prefix for the server
|
||||
MaxClients int // Maximum number of concurrent clients
|
||||
ReadRequestTimeout time.Duration // Timeout for reading requests
|
||||
WriteResponseTimeout time.Duration // Timeout for writing responses
|
||||
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
|
||||
GracefulTimeout time.Duration // Timeout for graceful shutdown
|
||||
CORSAllowOrigin string // CORS allowed origin
|
||||
@@ -27,6 +27,8 @@ type Config struct {
|
||||
SocketReusePort bool // Enable SO_REUSEPORT socket option
|
||||
HealthCheckPath string // Health check path from config
|
||||
|
||||
ResponseWriter responsewriter.Config // Response writer config
|
||||
|
||||
// TODO: We are not sure where to put it yet
|
||||
FreeMemoryInterval time.Duration // Interval for freeing memory
|
||||
LogMemStats bool // Log memory stats
|
||||
@@ -41,7 +43,6 @@ func NewDefaultConfig() Config {
|
||||
MaxClients: 2048,
|
||||
ReadRequestTimeout: 10 * time.Second,
|
||||
KeepAliveTimeout: 10 * time.Second,
|
||||
WriteResponseTimeout: 10 * time.Second,
|
||||
GracefulTimeout: 20 * time.Second,
|
||||
CORSAllowOrigin: "",
|
||||
Secret: "",
|
||||
@@ -50,6 +51,8 @@ func NewDefaultConfig() Config {
|
||||
HealthCheckPath: "",
|
||||
FreeMemoryInterval: 10 * time.Second,
|
||||
LogMemStats: false,
|
||||
|
||||
ResponseWriter: responsewriter.NewDefaultConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +75,10 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
|
||||
c.FreeMemoryInterval = time.Duration(config.FreeMemoryInterval) * time.Second
|
||||
c.LogMemStats = len(os.Getenv("IMGPROXY_LOG_MEM_STATS")) > 0
|
||||
|
||||
if _, err := responsewriter.LoadConfigFromEnv(&c.ResponseWriter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -89,10 +96,6 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout)
|
||||
}
|
||||
|
||||
if c.WriteResponseTimeout <= 0 {
|
||||
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
|
||||
}
|
||||
|
||||
if c.KeepAliveTimeout < 0 {
|
||||
return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout)
|
||||
}
|
||||
|
@@ -23,10 +23,13 @@ func (r *Router) WithMonitoring(h RouteHandler) RouteHandler {
|
||||
return h
|
||||
}
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
ctx, cancel, rw := monitoring.StartRequest(req.Context(), rw, req)
|
||||
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
ctx, cancel, newRw := monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req)
|
||||
defer cancel()
|
||||
|
||||
// Replace rw.ResponseWriter with new one returned from monitoring
|
||||
rw.SetHTTPResponseWriter(newRw)
|
||||
|
||||
return h(reqID, rw, req.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
@@ -37,7 +40,7 @@ func (r *Router) WithCORS(h RouteHandler) RouteHandler {
|
||||
return h
|
||||
}
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
|
||||
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
|
||||
|
||||
@@ -53,7 +56,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
|
||||
|
||||
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
|
||||
return h(reqID, rw, req)
|
||||
} else {
|
||||
@@ -64,7 +67,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
|
||||
|
||||
// WithPanic recovers panic and converts it to normal error
|
||||
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
|
||||
return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) {
|
||||
defer func() {
|
||||
// try to recover from panic
|
||||
rerr := recover()
|
||||
@@ -94,7 +97,7 @@ func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
||||
// WithReportError handles error reporting.
|
||||
// It should be placed after `WithMonitoring`, but before `WithPanic`.
|
||||
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
// Open the error context
|
||||
ctx := errorreport.StartRequest(req)
|
||||
req = req.WithContext(ctx)
|
||||
|
87
server/responsewriter/config.go
Normal file
87
server/responsewriter/config.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/ensure"
|
||||
)
|
||||
|
||||
// Config holds configuration for response writer
|
||||
type Config struct {
|
||||
SetCanonicalHeader bool // Indicates whether to set the canonical header
|
||||
DefaultTTL int // Default Cache-Control max-age= value for cached images
|
||||
FallbackImageTTL int // TTL for images served as fallbacks
|
||||
CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
|
||||
VaryValue string // Value for Vary header
|
||||
WriteResponseTimeout time.Duration // Timeout for response write operations
|
||||
}
|
||||
|
||||
// NewDefaultConfig returns a new Config instance with default values.
|
||||
func NewDefaultConfig() Config {
|
||||
return Config{
|
||||
SetCanonicalHeader: false,
|
||||
DefaultTTL: 31536000,
|
||||
FallbackImageTTL: 0,
|
||||
CacheControlPassthrough: false,
|
||||
VaryValue: "",
|
||||
WriteResponseTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadConfigFromEnv overrides configuration variables from environment
|
||||
func LoadConfigFromEnv(c *Config) (*Config, error) {
|
||||
c = ensure.Ensure(c, NewDefaultConfig)
|
||||
|
||||
c.SetCanonicalHeader = config.SetCanonicalHeader
|
||||
c.DefaultTTL = config.TTL
|
||||
c.FallbackImageTTL = config.FallbackImageTTL
|
||||
c.CacheControlPassthrough = config.CacheControlPassthrough
|
||||
c.WriteResponseTimeout = time.Duration(config.WriteResponseTimeout) * time.Second
|
||||
|
||||
vary := make([]string, 0)
|
||||
|
||||
if c.envEnableFormatDetection() {
|
||||
vary = append(vary, "Accept")
|
||||
}
|
||||
|
||||
if c.envEnableClientHints() {
|
||||
vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
|
||||
}
|
||||
|
||||
c.VaryValue = strings.Join(vary, ", ")
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) envEnableFormatDetection() bool {
|
||||
return config.AutoWebp ||
|
||||
config.EnforceWebp ||
|
||||
config.AutoAvif ||
|
||||
config.EnforceAvif ||
|
||||
config.AutoJxl ||
|
||||
config.EnforceJxl
|
||||
}
|
||||
|
||||
func (c *Config) envEnableClientHints() bool {
|
||||
return config.EnableClientHints
|
||||
}
|
||||
|
||||
// Validate checks config for errors
|
||||
func (c *Config) Validate() error {
|
||||
if c.DefaultTTL < 0 {
|
||||
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
|
||||
}
|
||||
|
||||
if c.FallbackImageTTL < 0 {
|
||||
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
|
||||
}
|
||||
|
||||
if c.WriteResponseTimeout <= 0 {
|
||||
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
114
server/responsewriter/config_test.go
Normal file
114
server/responsewriter/config_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ResponseWriterConfigSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *ResponseWriterConfigSuite) SetupSuite() {
|
||||
logrus.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
func (s *ResponseWriterConfigSuite) TearDownSuite() {
|
||||
logrus.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
func (s *ResponseWriterConfigSuite) TestLoadingVaryValueFromEnv() {
|
||||
defaultEnv := map[string]string{
|
||||
"IMGPROXY_AUTO_WEBP": "",
|
||||
"IMGPROXY_ENFORCE_WEBP": "",
|
||||
"IMGPROXY_AUTO_AVIF": "",
|
||||
"IMGPROXY_ENFORCE_AVIF": "",
|
||||
"IMGPROXY_AUTO_JXL": "",
|
||||
"IMGPROXY_ENFORCE_JXL": "",
|
||||
"IMGPROXY_ENABLE_CLIENT_HINTS": "",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
env map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "AutoWebP",
|
||||
env: map[string]string{"IMGPROXY_AUTO_WEBP": "true"},
|
||||
expected: "Accept",
|
||||
},
|
||||
{
|
||||
name: "EnforceWebP",
|
||||
env: map[string]string{"IMGPROXY_ENFORCE_WEBP": "true"},
|
||||
expected: "Accept",
|
||||
},
|
||||
{
|
||||
name: "AutoAVIF",
|
||||
env: map[string]string{"IMGPROXY_AUTO_AVIF": "true"},
|
||||
expected: "Accept",
|
||||
},
|
||||
{
|
||||
name: "EnforceAVIF",
|
||||
env: map[string]string{"IMGPROXY_ENFORCE_AVIF": "true"},
|
||||
expected: "Accept",
|
||||
},
|
||||
{
|
||||
name: "AutoJXL",
|
||||
env: map[string]string{"IMGPROXY_AUTO_JXL": "true"},
|
||||
expected: "Accept",
|
||||
},
|
||||
{
|
||||
name: "EnforceJXL",
|
||||
env: map[string]string{"IMGPROXY_ENFORCE_JXL": "true"},
|
||||
expected: "Accept",
|
||||
},
|
||||
{
|
||||
name: "EnableClientHints",
|
||||
env: map[string]string{"IMGPROXY_ENABLE_CLIENT_HINTS": "true"},
|
||||
expected: "Sec-CH-DPR, DPR, Sec-CH-Width, Width",
|
||||
},
|
||||
{
|
||||
name: "Combined",
|
||||
env: map[string]string{
|
||||
"IMGPROXY_AUTO_WEBP": "true",
|
||||
"IMGPROXY_ENABLE_CLIENT_HINTS": "true",
|
||||
},
|
||||
expected: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(fmt.Sprintf("%v", tc.env), func() {
|
||||
// Set default environment variables
|
||||
for key, value := range defaultEnv {
|
||||
s.T().Setenv(key, value)
|
||||
}
|
||||
// Set environment variables
|
||||
for key, value := range tc.env {
|
||||
s.T().Setenv(key, value)
|
||||
}
|
||||
|
||||
// TODO: Remove when we removed global config
|
||||
config.Reset()
|
||||
config.Configure()
|
||||
|
||||
// Load config
|
||||
cfg, err := LoadConfigFromEnv(nil)
|
||||
|
||||
// Assert expected values
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(tc.expected, cfg.VaryValue)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWriterConfig(t *testing.T) {
|
||||
suite.Run(t, new(ResponseWriterConfigSuite))
|
||||
}
|
30
server/responsewriter/factory.go
Normal file
30
server/responsewriter/factory.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package responsewriter
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Factory is a struct that creates response writers.
|
||||
type Factory struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewFactory(config *Config) (*Factory, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Factory{config}, nil
|
||||
}
|
||||
|
||||
// NewWriter wraps [http.ResponseWriter] into [Writer].
|
||||
func (f *Factory) NewWriter(rw http.ResponseWriter) *Writer {
|
||||
w := &Writer{
|
||||
config: f.config,
|
||||
result: make(http.Header),
|
||||
originHeaders: make(http.Header),
|
||||
maxAge: -1,
|
||||
}
|
||||
|
||||
w.SetHTTPResponseWriter(rw)
|
||||
|
||||
return w
|
||||
}
|
226
server/responsewriter/writer.go
Normal file
226
server/responsewriter/writer.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
)
|
||||
|
||||
// Just aliases for [http.ResponseWriter] and [http.ResponseController].
|
||||
// We need them to make them private in [Writer] so they can't be accessed directly.
|
||||
type httpResponseWriter = http.ResponseWriter
|
||||
type httpResponseController = *http.ResponseController
|
||||
|
||||
// Writer is an implementation of [http.ResponseWriter] with additional
|
||||
// functionality for managing response headers.
|
||||
type Writer struct {
|
||||
httpResponseWriter
|
||||
httpResponseController
|
||||
|
||||
config *Config // Configuration for the writer
|
||||
originHeaders http.Header // Original response headers
|
||||
result http.Header // Headers to be written to the response
|
||||
maxAge int // Current max age for Cache-Control header
|
||||
|
||||
beforeWriteOnce sync.Once
|
||||
}
|
||||
|
||||
// HTTPResponseWriter returns the underlying http.ResponseWriter.
|
||||
func (w *Writer) HTTPResponseWriter() http.ResponseWriter {
|
||||
return w.httpResponseWriter
|
||||
}
|
||||
|
||||
// SetHTTPResponseWriter replaces the underlying http.ResponseWriter.
|
||||
func (w *Writer) SetHTTPResponseWriter(rw http.ResponseWriter) {
|
||||
w.httpResponseWriter = rw
|
||||
w.httpResponseController = http.NewResponseController(rw)
|
||||
}
|
||||
|
||||
// SetOriginHeaders sets the origin headers for the request.
|
||||
func (w *Writer) SetOriginHeaders(h http.Header) {
|
||||
w.originHeaders = h
|
||||
}
|
||||
|
||||
// SetIsFallbackImage sets the Fallback-Image header to
|
||||
// indicate that the fallback image was used.
|
||||
func (w *Writer) SetIsFallbackImage() {
|
||||
// We set maxAge to FallbackImageTTL if it's explicitly passed
|
||||
if w.config.FallbackImageTTL < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// However, we should not overwrite existing value if set (or greater than ours)
|
||||
if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL {
|
||||
w.maxAge = w.config.FallbackImageTTL
|
||||
}
|
||||
}
|
||||
|
||||
// SetExpires sets the TTL from time
|
||||
func (w *Writer) SetExpires(expires *time.Time) {
|
||||
if expires == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert current maxAge to time
|
||||
currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second)
|
||||
|
||||
// If maxAge outlives expires or was not set, we'll use expires as maxAge.
|
||||
if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
|
||||
w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
|
||||
}
|
||||
}
|
||||
|
||||
// SetVary sets the Vary header
|
||||
func (w *Writer) SetVary() {
|
||||
if val := w.config.VaryValue; len(val) > 0 {
|
||||
w.result.Set(httpheaders.Vary, val)
|
||||
}
|
||||
}
|
||||
|
||||
// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
|
||||
func (w *Writer) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
|
||||
value := httpheaders.ContentDispositionValue(
|
||||
originURL,
|
||||
filename,
|
||||
ext,
|
||||
contentType,
|
||||
returnAttachment,
|
||||
)
|
||||
|
||||
if value != "" {
|
||||
w.result.Set(httpheaders.ContentDisposition, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Passthrough copies specified headers from the original response headers to the response headers.
|
||||
func (w *Writer) Passthrough(only ...string) {
|
||||
httpheaders.Copy(w.originHeaders, w.result, only)
|
||||
}
|
||||
|
||||
// CopyFrom copies specified headers from the headers object. Please note that
|
||||
// all the past operations may overwrite those values.
|
||||
func (w *Writer) CopyFrom(headers http.Header, only []string) {
|
||||
httpheaders.Copy(headers, w.result, only)
|
||||
}
|
||||
|
||||
// SetContentLength sets the Content-Length header
|
||||
func (w *Writer) SetContentLength(contentLength int) {
|
||||
if contentLength < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
w.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
|
||||
}
|
||||
|
||||
// SetContentType sets the Content-Type header
|
||||
func (w *Writer) SetContentType(mime string) {
|
||||
w.result.Set(httpheaders.ContentType, mime)
|
||||
}
|
||||
|
||||
// writeCanonical sets the Link header with the canonical URL.
|
||||
// It is mandatory for any response if enabled in the configuration.
|
||||
func (w *Writer) SetCanonical(url string) {
|
||||
if !w.config.SetCanonicalHeader {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
|
||||
value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
|
||||
w.result.Set(httpheaders.Link, value)
|
||||
}
|
||||
}
|
||||
|
||||
// setCacheControl sets the Cache-Control header with the specified value.
|
||||
func (w *Writer) setCacheControl(value int) bool {
|
||||
if value <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
w.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
|
||||
return true
|
||||
}
|
||||
|
||||
// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
|
||||
func (w *Writer) setCacheControlNoCache() {
|
||||
w.result.Set(httpheaders.CacheControl, "no-cache")
|
||||
}
|
||||
|
||||
// setCacheControlPassthrough sets the Cache-Control header from the request
|
||||
// if passthrough is enabled in the configuration.
|
||||
func (w *Writer) setCacheControlPassthrough() bool {
|
||||
if !w.config.CacheControlPassthrough || w.maxAge > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if val := w.originHeaders.Get(httpheaders.CacheControl); val != "" {
|
||||
w.result.Set(httpheaders.CacheControl, val)
|
||||
return true
|
||||
}
|
||||
|
||||
if val := w.originHeaders.Get(httpheaders.Expires); val != "" {
|
||||
if t, err := time.Parse(http.TimeFormat, val); err == nil {
|
||||
maxAge := max(0, int(time.Until(t).Seconds()))
|
||||
return w.setCacheControl(maxAge)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// setCSP sets the Content-Security-Policy header to prevent script execution.
|
||||
func (w *Writer) setCSP() {
|
||||
w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
|
||||
}
|
||||
|
||||
// flushHeaders writes the headers to the response writer. It does not overwrite
|
||||
// target headers, which were set outside the header writer.
|
||||
func (w *Writer) flushHeaders() {
|
||||
// Then, let's try to set Cache-Control using priority order
|
||||
switch {
|
||||
case w.setCacheControl(w.maxAge): // First, try set explicit
|
||||
case w.setCacheControlPassthrough(): // Try to pick up from request headers
|
||||
case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value
|
||||
default:
|
||||
w.setCacheControlNoCache() // By default we use no-cache
|
||||
}
|
||||
|
||||
w.setCSP()
|
||||
|
||||
// Copy all headers to the response without overwriting existing ones
|
||||
httpheaders.CopyAll(w.result, w.Header(), false)
|
||||
}
|
||||
|
||||
// beforeWrite is called before [WriteHeader] and [Write]
|
||||
func (w *Writer) beforeWrite() {
|
||||
w.beforeWriteOnce.Do(func() {
|
||||
// We're going to start writing response.
|
||||
// Set write deadline.
|
||||
w.SetWriteDeadline(time.Now().Add(w.config.WriteResponseTimeout))
|
||||
|
||||
// Flush headers before we write anything
|
||||
w.flushHeaders()
|
||||
})
|
||||
}
|
||||
|
||||
// WriteHeader writes the HTTP response header.
|
||||
//
|
||||
// It ensures that all headers are flushed before writing the status code.
|
||||
func (w *Writer) WriteHeader(statusCode int) {
|
||||
w.beforeWrite()
|
||||
|
||||
w.httpResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Write writes the HTTP response body.
|
||||
//
|
||||
// It ensures that all headers are flushed before writing the body.
|
||||
func (w *Writer) Write(b []byte) (int, error) {
|
||||
w.beforeWrite()
|
||||
|
||||
return w.httpResponseWriter.Write(b)
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package headerwriter
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type HeaderWriterSuite struct {
|
||||
type ResponseWriterSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
@@ -22,16 +22,18 @@ type writerTestCase struct {
|
||||
req http.Header
|
||||
res http.Header
|
||||
config Config
|
||||
fn func(*Request)
|
||||
fn func(*Writer)
|
||||
}
|
||||
|
||||
func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
func (s *ResponseWriterSuite) TestHeaderCases() {
|
||||
expires := time.Date(2030, 8, 1, 0, 0, 0, 0, time.UTC)
|
||||
expiresSeconds := strconv.Itoa(int(time.Until(expires).Seconds()))
|
||||
|
||||
shortExpires := time.Now().Add(10 * time.Second)
|
||||
shortExpiresSeconds := strconv.Itoa(int(time.Until(shortExpires).Seconds()))
|
||||
|
||||
writeResponseTimeout := 10 * time.Second
|
||||
|
||||
tt := []writerTestCase{
|
||||
{
|
||||
name: "MinimalHeaders",
|
||||
@@ -44,8 +46,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
SetCanonicalHeader: false,
|
||||
DefaultTTL: 0,
|
||||
CacheControlPassthrough: false,
|
||||
EnableClientHints: false,
|
||||
SetVaryAccept: false,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -60,6 +61,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
CacheControlPassthrough: true,
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -74,6 +76,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
CacheControlPassthrough: true,
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -88,6 +91,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
CacheControlPassthrough: true,
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -101,8 +105,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
SetCanonicalHeader: true,
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetCanonical("https://example.com/image.jpg")
|
||||
},
|
||||
},
|
||||
@@ -116,6 +121,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
SetCanonicalHeader: true,
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -128,8 +134,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
SetCanonicalHeader: false,
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetCanonical("https://example.com/image.jpg")
|
||||
},
|
||||
},
|
||||
@@ -143,8 +150,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
DefaultTTL: 3600,
|
||||
FallbackImageTTL: 1,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetIsFallbackImage()
|
||||
},
|
||||
},
|
||||
@@ -157,8 +165,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
},
|
||||
config: Config{
|
||||
DefaultTTL: math.MaxInt32,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetExpires(&expires)
|
||||
},
|
||||
},
|
||||
@@ -172,8 +181,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
config: Config{
|
||||
DefaultTTL: math.MaxInt32,
|
||||
FallbackImageTTL: 600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetIsFallbackImage()
|
||||
w.SetExpires(&shortExpires)
|
||||
},
|
||||
@@ -187,10 +197,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{
|
||||
EnableClientHints: true,
|
||||
SetVaryAccept: true,
|
||||
VaryValue: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetVary()
|
||||
},
|
||||
},
|
||||
@@ -204,8 +214,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
httpheaders.CacheControl: []string{"no-cache"},
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{},
|
||||
fn: func(w *Request) {
|
||||
config: Config{
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Writer) {
|
||||
w.Passthrough("X-Test")
|
||||
},
|
||||
},
|
||||
@@ -217,8 +229,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
httpheaders.CacheControl: []string{"no-cache"},
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{},
|
||||
fn: func(w *Request) {
|
||||
config: Config{
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Writer) {
|
||||
h := http.Header{}
|
||||
h.Set("X-From", "baz")
|
||||
w.CopyFrom(h, []string{"X-From"})
|
||||
@@ -232,8 +246,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
httpheaders.CacheControl: []string{"no-cache"},
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{},
|
||||
fn: func(w *Request) {
|
||||
config: Config{
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Writer) {
|
||||
w.SetContentLength(123)
|
||||
},
|
||||
},
|
||||
@@ -245,8 +261,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
httpheaders.CacheControl: []string{"no-cache"},
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{},
|
||||
fn: func(w *Request) {
|
||||
config: Config{
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Writer) {
|
||||
w.SetContentType("image/png")
|
||||
},
|
||||
},
|
||||
@@ -259,57 +277,29 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
},
|
||||
config: Config{
|
||||
DefaultTTL: 3600,
|
||||
WriteResponseTimeout: writeResponseTimeout,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
fn: func(w *Writer) {
|
||||
w.SetExpires(nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WriteVaryAcceptOnly",
|
||||
req: http.Header{},
|
||||
res: http.Header{
|
||||
httpheaders.Vary: []string{"Accept"},
|
||||
httpheaders.CacheControl: []string{"no-cache"},
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{
|
||||
SetVaryAccept: true,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
w.SetVary()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WriteVaryClientHintsOnly",
|
||||
req: http.Header{},
|
||||
res: http.Header{
|
||||
httpheaders.Vary: []string{"Sec-CH-DPR, DPR, Sec-CH-Width, Width"},
|
||||
httpheaders.CacheControl: []string{"no-cache"},
|
||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||
},
|
||||
config: Config{
|
||||
EnableClientHints: true,
|
||||
},
|
||||
fn: func(w *Request) {
|
||||
w.SetVary()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
s.Run(tc.name, func() {
|
||||
factory, err := New(&tc.config)
|
||||
factory, err := NewFactory(&tc.config)
|
||||
s.Require().NoError(err)
|
||||
|
||||
writer := factory.NewRequest()
|
||||
r := httptest.NewRecorder()
|
||||
|
||||
writer := factory.NewWriter(r)
|
||||
writer.SetOriginHeaders(tc.req)
|
||||
|
||||
if tc.fn != nil {
|
||||
tc.fn(writer)
|
||||
}
|
||||
|
||||
r := httptest.NewRecorder()
|
||||
writer.Write(r)
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
s.Require().Equal(tc.res, r.Header())
|
||||
})
|
||||
@@ -317,5 +307,5 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
||||
}
|
||||
|
||||
func TestHeaderWriter(t *testing.T) {
|
||||
suite.Run(t, new(HeaderWriterSuite))
|
||||
suite.Run(t, new(ResponseWriterSuite))
|
||||
}
|
@@ -11,6 +11,7 @@ import (
|
||||
nanoid "github.com/matoous/go-nanoid/v2"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,8 +24,10 @@ var (
|
||||
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
||||
)
|
||||
|
||||
type ResponseWriter = *responsewriter.Writer
|
||||
|
||||
// RouteHandler is a function that handles HTTP requests.
|
||||
type RouteHandler func(string, http.ResponseWriter, *http.Request) error
|
||||
type RouteHandler func(string, ResponseWriter, *http.Request) error
|
||||
|
||||
// Middleware is a function that wraps a RouteHandler with additional functionality.
|
||||
type Middleware func(next RouteHandler) RouteHandler
|
||||
@@ -40,6 +43,9 @@ type route struct {
|
||||
|
||||
// Router is responsible for routing HTTP requests
|
||||
type Router struct {
|
||||
// Response writers factory
|
||||
rwFactory *responsewriter.Factory
|
||||
|
||||
// config represents the server configuration
|
||||
config *Config
|
||||
|
||||
@@ -53,7 +59,15 @@ func NewRouter(config *Config) (*Router, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Router{config: config}, nil
|
||||
rwf, err := responsewriter.NewFactory(&config.ResponseWriter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Router{
|
||||
rwFactory: rwf,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// add adds an abitary route to the router
|
||||
@@ -114,8 +128,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
req, timeoutCancel := startRequestTimer(req)
|
||||
defer timeoutCancel()
|
||||
|
||||
// Create the response writer which times out on write
|
||||
rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
|
||||
// Create the [ResponseWriter]
|
||||
rww := r.rwFactory.NewWriter(rw)
|
||||
|
||||
// Get/create request ID
|
||||
reqID := r.getRequestID(req)
|
||||
@@ -123,8 +137,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Replace request IP from headers
|
||||
r.replaceRemoteAddr(req)
|
||||
|
||||
rw.Header().Set(httpheaders.Server, defaultServerName)
|
||||
rw.Header().Set(httpheaders.XRequestID, reqID)
|
||||
rww.Header().Set(httpheaders.Server, defaultServerName)
|
||||
rww.Header().Set(httpheaders.XRequestID, reqID)
|
||||
|
||||
for _, rr := range r.routes {
|
||||
if !rr.isMatch(req) {
|
||||
@@ -138,18 +152,18 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
LogRequest(reqID, req)
|
||||
}
|
||||
|
||||
rr.handler(reqID, rw, req)
|
||||
rr.handler(reqID, rww, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Means that we have not found matching route
|
||||
LogRequest(reqID, req)
|
||||
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
|
||||
r.NotFoundHandler(reqID, rw, req)
|
||||
r.NotFoundHandler(reqID, rww, req)
|
||||
}
|
||||
|
||||
// NotFoundHandler is default 404 handler
|
||||
func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
func (r *Router) NotFoundHandler(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||
@@ -158,7 +172,7 @@ func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http
|
||||
}
|
||||
|
||||
// OkHandler is a default 200 OK handler
|
||||
func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
func (r *Router) OkHandler(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||
|
@@ -30,7 +30,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
|
||||
var capturedMethod string
|
||||
var capturedPath string
|
||||
|
||||
getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
getHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
capturedMethod = req.Method
|
||||
capturedPath = req.URL.Path
|
||||
rw.WriteHeader(200)
|
||||
@@ -38,7 +38,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
|
||||
return nil
|
||||
}
|
||||
|
||||
optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
optionsHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
capturedMethod = req.Method
|
||||
capturedPath = req.URL.Path
|
||||
rw.WriteHeader(200)
|
||||
@@ -46,7 +46,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
|
||||
return nil
|
||||
}
|
||||
|
||||
headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
headHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
capturedMethod = req.Method
|
||||
capturedPath = req.URL.Path
|
||||
rw.WriteHeader(200)
|
||||
@@ -114,20 +114,20 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
|
||||
var order []string
|
||||
|
||||
middleware1 := func(next RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
order = append(order, "middleware1")
|
||||
return next(reqID, rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
middleware2 := func(next RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
order = append(order, "middleware2")
|
||||
return next(reqID, rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
order = append(order, "handler")
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
@@ -146,7 +146,7 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
|
||||
|
||||
// TestServeHTTP tests ServeHTTP method
|
||||
func (s *RouterTestSuite) TestServeHTTP() {
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set("Custom-Header", "test-value")
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("success"))
|
||||
@@ -169,7 +169,7 @@ func (s *RouterTestSuite) TestServeHTTP() {
|
||||
|
||||
// TestRequestID checks request ID generation and validation
|
||||
func (s *RouterTestSuite) TestRequestID() {
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
@@ -209,7 +209,7 @@ func (s *RouterTestSuite) TestRequestID() {
|
||||
|
||||
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
|
||||
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
@@ -229,7 +229,7 @@ func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
||||
// Test IP address handling
|
||||
func (s *RouterTestSuite) TestReplaceIP() {
|
||||
var capturedRemoteAddr string
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
capturedRemoteAddr = req.RemoteAddr
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
@@ -298,7 +298,7 @@ func (s *RouterTestSuite) TestReplaceIP() {
|
||||
// TestRouteOrder checks exact/non-exact insertion order
|
||||
func (s *RouterTestSuite) TestRouteOrder() {
|
||||
|
||||
h := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
h := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -29,10 +29,14 @@ func (s *ServerTestSuite) SetupTest() {
|
||||
s.blankRouter = r
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
|
||||
return s.blankRouter.rwFactory.NewWriter(rw)
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
|
||||
ctx, cancel := context.WithCancel(s.T().Context())
|
||||
|
||||
@@ -121,7 +125,7 @@ func (s *ServerTestSuite) TestWithCORS() {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
|
||||
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
|
||||
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
|
||||
@@ -170,7 +174,7 @@ func (s *ServerTestSuite) TestWithSecret() {
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
err = wrappedHandler("test-req-id", rw, req)
|
||||
err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
|
||||
if tt.expectError {
|
||||
s.Require().Error(err)
|
||||
@@ -182,7 +186,7 @@ func (s *ServerTestSuite) TestWithSecret() {
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoSuccess() {
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
}
|
||||
@@ -192,14 +196,14 @@ func (s *ServerTestSuite) TestIntoSuccess() {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
|
||||
s.Equal(http.StatusOK, rw.Code)
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoWithError() {
|
||||
testError := errors.New("test error")
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||
return testError
|
||||
}
|
||||
|
||||
@@ -208,7 +212,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
|
||||
s.Equal(http.StatusInternalServerError, rw.Code)
|
||||
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
|
||||
@@ -216,7 +220,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
|
||||
|
||||
func (s *ServerTestSuite) TestIntoPanicWithError() {
|
||||
testError := errors.New("panic error")
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||
panic(testError)
|
||||
}
|
||||
|
||||
@@ -226,7 +230,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.NotPanics(func() {
|
||||
err := wrappedHandler("test-req-id", rw, req)
|
||||
err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
s.Require().Error(err, "panic error")
|
||||
})
|
||||
|
||||
@@ -234,7 +238,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
@@ -245,12 +249,12 @@ func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
||||
|
||||
// Should re-panic with ErrAbortHandler
|
||||
s.Panics(func() {
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||
panic("string panic")
|
||||
}
|
||||
|
||||
@@ -261,7 +265,7 @@ func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
||||
|
||||
// Should re-panic with non-error panics
|
||||
s.NotPanics(func() {
|
||||
err := wrappedHandler("test-req-id", rw, req)
|
||||
err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||
s.Require().Error(err, "string panic")
|
||||
})
|
||||
}
|
||||
|
@@ -1,47 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// timeoutResponse manages response writer with timeout. It has
|
||||
// timeout on all write methods.
|
||||
type timeoutResponse struct {
|
||||
http.ResponseWriter
|
||||
controller *http.ResponseController
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// newTimeoutResponse creates a new timeoutResponse
|
||||
func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
|
||||
return &timeoutResponse{
|
||||
ResponseWriter: rw,
|
||||
controller: http.NewResponseController(rw),
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements http.ResponseWriter.Write
|
||||
func (rw *timeoutResponse) Write(b []byte) (int, error) {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
rw.withWriteDeadline(func() {
|
||||
n, err = rw.ResponseWriter.Write(b)
|
||||
})
|
||||
return n, err
|
||||
}
|
||||
|
||||
// withWriteDeadline executes a Write* function with a deadline
|
||||
func (rw *timeoutResponse) withWriteDeadline(f func()) {
|
||||
deadline := time.Now().Add(rw.timeout)
|
||||
|
||||
// Set write deadline
|
||||
rw.controller.SetWriteDeadline(deadline)
|
||||
|
||||
// Reset write deadline after method has finished
|
||||
defer rw.controller.SetWriteDeadline(time.Time{})
|
||||
f()
|
||||
}
|
Reference in New Issue
Block a user