Rebuild headerwriter to server.ResponseWriter

This commit is contained in:
DarthSim
2025-09-11 12:04:04 +03:00
committed by Sergei Aleksandrovich
parent 53645688fb
commit 1f6d007948
24 changed files with 720 additions and 631 deletions

View File

@@ -6,7 +6,6 @@ import (
"github.com/imgproxy/imgproxy/v3/fetcher"
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
"github.com/imgproxy/imgproxy/v3/headerwriter"
"github.com/imgproxy/imgproxy/v3/semaphores"
"github.com/imgproxy/imgproxy/v3/server"
)
@@ -19,7 +18,6 @@ type HandlerConfigs struct {
// Config represents an instance configuration
type Config struct {
HeaderWriter headerwriter.Config
Semaphores semaphores.Config
FallbackImage auximageprovider.StaticConfig
WatermarkImage auximageprovider.StaticConfig
@@ -31,7 +29,6 @@ type Config struct {
// NewDefaultConfig creates a new default configuration
func NewDefaultConfig() Config {
return Config{
HeaderWriter: headerwriter.NewDefaultConfig(),
Semaphores: semaphores.NewDefaultConfig(),
FallbackImage: auximageprovider.NewDefaultStaticConfig(),
WatermarkImage: auximageprovider.NewDefaultStaticConfig(),
@@ -62,10 +59,6 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
return nil, err
}
if _, err = headerwriter.LoadConfigFromEnv(&c.HeaderWriter); err != nil {
return nil, err
}
if _, err = semaphores.LoadConfigFromEnv(&c.Semaphores); err != nil {
return nil, err
}

View File

@@ -22,7 +22,7 @@ func New() *Handler {
// Execute handles the health request
func (h *Handler) Execute(
reqID string,
rw http.ResponseWriter,
rw server.ResponseWriter,
req *http.Request,
) error {
var (

View File

@@ -6,11 +6,18 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
)
func TestHealthHandler(t *testing.T) {
// Create responsewriter.Factory
rwConf := responsewriter.NewDefaultConfig()
rwf, err := responsewriter.NewFactory(&rwConf)
require.NoError(t, err)
// Create a ResponseRecorder to record the response
rr := httptest.NewRecorder()
@@ -18,7 +25,7 @@ func TestHealthHandler(t *testing.T) {
h := New()
// Call the handler function directly (no need for actual HTTP request)
h.Execute("test-req-id", rr, nil)
h.Execute("test-req-id", rwf.NewWriter(rr), nil)
// Check that we get a valid response (either 200 or 500 depending on vips state)
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/server"
)
//go:embed body.html
@@ -21,7 +22,7 @@ func New() *Handler {
// Execute handles the landing request
func (h *Handler) Execute(
reqID string,
rw http.ResponseWriter,
rw server.ResponseWriter,
req *http.Request,
) error {
rw.Header().Set(httpheaders.ContentType, "text/html")

View File

@@ -9,7 +9,6 @@ import (
"github.com/imgproxy/imgproxy/v3/errorreport"
"github.com/imgproxy/imgproxy/v3/handlers"
"github.com/imgproxy/imgproxy/v3/handlers/stream"
"github.com/imgproxy/imgproxy/v3/headerwriter"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/monitoring"
@@ -17,11 +16,11 @@ import (
"github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/security"
"github.com/imgproxy/imgproxy/v3/semaphores"
"github.com/imgproxy/imgproxy/v3/server"
)
// HandlerContext provides access to shared handler dependencies
type HandlerContext interface {
HeaderWriter() *headerwriter.Writer
Semaphores() *semaphores.Semaphores
FallbackImage() auximageprovider.Provider
WatermarkImage() auximageprovider.Provider
@@ -56,7 +55,7 @@ func New(
// Execute handles the image processing request
func (h *Handler) Execute(
reqID string,
rw http.ResponseWriter,
rw server.ResponseWriter,
req *http.Request,
) error {
// Increment the number of requests in progress
@@ -86,7 +85,6 @@ func (h *Handler) Execute(
po: po,
imageURL: imageURL,
monitoringMeta: mm,
hwr: h.HeaderWriter().NewRequest(),
}
return hReq.execute(ctx)

View File

@@ -7,7 +7,6 @@ import (
"github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/handlers"
"github.com/imgproxy/imgproxy/v3/headerwriter"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/imagetype"
"github.com/imgproxy/imgproxy/v3/monitoring"
@@ -23,12 +22,11 @@ type request struct {
reqID string
req *http.Request
rw http.ResponseWriter
rw server.ResponseWriter
config *Config
po *options.ProcessingOptions
imageURL string
monitoringMeta monitoring.Meta
hwr *headerwriter.Request
}
// execute handles the actual processing logic
@@ -84,13 +82,13 @@ func (r *request) execute(ctx context.Context) error {
var nmErr fetcher.NotModifiedError
if errors.As(err, &nmErr) {
r.hwr.SetOriginHeaders(nmErr.Headers())
r.rw.SetOriginHeaders(nmErr.Headers())
return r.respondWithNotModified()
}
// Prepare to write image response headers
r.hwr.SetOriginHeaders(originHeaders)
r.rw.SetOriginHeaders(originHeaders)
// If error is not related to NotModified, respond with fallback image and replace image data
if err != nil {
@@ -123,7 +121,7 @@ func (r *request) execute(ctx context.Context) error {
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryProcessing))
}
// Write debug headers. It seems unlogical to move they to headerwriter since they're
// Write debug headers. It seems unlogical to move they to responsewriter since they're
// not used anywhere else.
err = r.writeDebugHeaders(result, originData)
if err != nil {

View File

@@ -123,8 +123,8 @@ func (r *request) handleDownloadError(
headers.Del(httpheaders.Expires)
headers.Del(httpheaders.LastModified)
r.hwr.SetOriginHeaders(headers)
r.hwr.SetIsFallbackImage()
r.rw.SetOriginHeaders(headers)
r.rw.SetIsFallbackImage()
return data, statusCode, nil
}
@@ -186,19 +186,17 @@ func (r *request) writeDebugHeaders(result *processing.Result, originData imaged
// respondWithNotModified writes not-modified response
func (r *request) respondWithNotModified() error {
r.hwr.SetExpires(r.po.Expires)
r.hwr.SetVary()
r.rw.SetExpires(r.po.Expires)
r.rw.SetVary()
if r.config.LastModifiedEnabled {
r.hwr.Passthrough(httpheaders.LastModified)
r.rw.Passthrough(httpheaders.LastModified)
}
if r.config.ETagEnabled {
r.hwr.Passthrough(httpheaders.Etag)
r.rw.Passthrough(httpheaders.Etag)
}
r.hwr.Write(r.rw)
r.rw.WriteHeader(http.StatusNotModified)
server.LogResponse(
@@ -221,29 +219,27 @@ func (r *request) respondWithImage(statusCode int, resultData imagedata.ImageDat
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryImageDataSize))
}
r.hwr.SetContentType(resultData.Format().Mime())
r.hwr.SetContentLength(resultSize)
r.hwr.SetContentDisposition(
r.rw.SetContentType(resultData.Format().Mime())
r.rw.SetContentLength(resultSize)
r.rw.SetContentDisposition(
r.imageURL,
r.po.Filename,
resultData.Format().Ext(),
"",
r.po.ReturnAttachment,
)
r.hwr.SetExpires(r.po.Expires)
r.hwr.SetVary()
r.hwr.SetCanonical(r.imageURL)
r.rw.SetExpires(r.po.Expires)
r.rw.SetVary()
r.rw.SetCanonical(r.imageURL)
if r.config.LastModifiedEnabled {
r.hwr.Passthrough(httpheaders.LastModified)
r.rw.Passthrough(httpheaders.LastModified)
}
if r.config.ETagEnabled {
r.hwr.Passthrough(httpheaders.Etag)
r.rw.Passthrough(httpheaders.Etag)
}
r.hwr.Write(r.rw)
r.rw.WriteHeader(statusCode)
_, err = io.Copy(r.rw, resultData.Reader())

View File

@@ -6,16 +6,16 @@ import (
"net/http"
"sync"
log "github.com/sirupsen/logrus"
"github.com/imgproxy/imgproxy/v3/cookies"
"github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/headerwriter"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/monitoring"
"github.com/imgproxy/imgproxy/v3/monitoring/stats"
"github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/server"
log "github.com/sirupsen/logrus"
)
const (
@@ -37,7 +37,6 @@ var (
type Handler struct {
config *Config // Configuration for the streamer
fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching
hw *headerwriter.Writer // Configured HeaderWriter instance
}
// request holds the parameters and state for a single streaming request
@@ -47,12 +46,11 @@ type request struct {
imageURL string
reqID string
po *options.ProcessingOptions
rw http.ResponseWriter
hw *headerwriter.Request
rw server.ResponseWriter
}
// New creates new handler object
func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Handler, error) {
func New(config *Config, fetcher *fetcher.Fetcher) (*Handler, error) {
if err := config.Validate(); err != nil {
return nil, err
}
@@ -60,7 +58,6 @@ func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Ha
return &Handler{
fetcher: fetcher,
config: config,
hw: hw,
}, nil
}
@@ -71,7 +68,7 @@ func (s *Handler) Execute(
imageURL string,
reqID string,
po *options.ProcessingOptions,
rw http.ResponseWriter,
rw server.ResponseWriter,
) error {
stream := &request{
handler: s,
@@ -80,7 +77,6 @@ func (s *Handler) Execute(
reqID: reqID,
po: po,
rw: rw,
hw: s.hw.NewRequest(),
}
return stream.execute(ctx)
@@ -118,17 +114,14 @@ func (s *request) execute(ctx context.Context) error {
}
// Output streaming response headers
s.hw.SetOriginHeaders(res.Header)
s.hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
s.hw.SetContentLength(int(res.ContentLength))
s.hw.SetCanonical(s.imageURL)
s.hw.SetExpires(s.po.Expires)
s.rw.SetOriginHeaders(res.Header)
s.rw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
s.rw.SetContentLength(int(res.ContentLength))
s.rw.SetCanonical(s.imageURL)
s.rw.SetExpires(s.po.Expires)
// Set the Content-Disposition header
s.setContentDisposition(r.URL().Path, res, s.hw)
// Write headers from writer
s.hw.Write(s.rw)
s.setContentDisposition(r.URL().Path, res)
// Copy the status code from the original response
s.rw.WriteHeader(res.StatusCode)
@@ -158,7 +151,7 @@ func (s *request) getImageRequestHeaders() http.Header {
}
// setContentDisposition writes the headers to the response writer
func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response, hw *headerwriter.Request) {
func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response) {
// Try to set correct Content-Disposition file name and extension
if serverResponse.StatusCode < 200 || serverResponse.StatusCode >= 300 {
return
@@ -166,7 +159,7 @@ func (s *request) setContentDisposition(imagePath string, serverResponse *http.R
ct := serverResponse.Header.Get(httpheaders.ContentType)
hw.SetContentDisposition(
s.rw.SetContentDisposition(
imagePath,
s.po.Filename,
"",

View File

@@ -1,7 +1,6 @@
package stream
import (
"context"
"fmt"
"io"
"net/http"
@@ -17,9 +16,10 @@ import (
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/fetcher"
"github.com/imgproxy/imgproxy/v3/headerwriter"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
"github.com/imgproxy/imgproxy/v3/testutil"
)
const (
@@ -27,14 +27,54 @@ const (
)
type HandlerTestSuite struct {
suite.Suite
handler *Handler
testutil.LazySuite
rwConf testutil.LazyObj[*responsewriter.Config]
rwFactory testutil.LazyObj[*responsewriter.Factory]
config testutil.LazyObj[*Config]
handler testutil.LazyObj[*Handler]
}
func (s *HandlerTestSuite) SetupSuite() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
s.rwConf, _ = testutil.NewLazySuiteObj(
s,
func() (*responsewriter.Config, error) {
c := responsewriter.NewDefaultConfig()
return &c, nil
},
)
s.rwFactory, _ = testutil.NewLazySuiteObj(
s,
func() (*responsewriter.Factory, error) {
return responsewriter.NewFactory(s.rwConf())
},
)
s.config, _ = testutil.NewLazySuiteObj(
s,
func() (*Config, error) {
c := NewDefaultConfig()
return &c, nil
},
)
s.handler, _ = testutil.NewLazySuiteObj(
s,
func() (*Handler, error) {
fc := fetcher.NewDefaultConfig()
fetcher, err := fetcher.New(&fc)
s.Require().NoError(err)
return New(s.config(), fetcher)
},
)
// Silence logs during tests
logrus.SetOutput(io.Discard)
}
@@ -47,21 +87,11 @@ func (s *HandlerTestSuite) TearDownSuite() {
func (s *HandlerTestSuite) SetupTest() {
config.Reset()
config.AllowLoopbackSourceAddresses = true
}
fc := fetcher.NewDefaultConfig()
fetcher, err := fetcher.New(&fc)
s.Require().NoError(err)
cfg := NewDefaultConfig()
hwc := headerwriter.NewDefaultConfig()
hw, err := headerwriter.New(&hwc)
s.Require().NoError(err)
h, err := New(&cfg, hw, fetcher)
s.Require().NoError(err)
s.handler = h
func (s *HandlerTestSuite) SetupSubTest() {
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
s.ResetLazyObjects()
}
func (s *HandlerTestSuite) readTestFile(name string) []byte {
@@ -70,6 +100,24 @@ func (s *HandlerTestSuite) readTestFile(name string) []byte {
return data
}
func (s *HandlerTestSuite) execute(
imageURL string,
header http.Header,
po *options.ProcessingOptions,
) *httptest.ResponseRecorder {
req := httptest.NewRequest("GET", "/", nil)
httpheaders.CopyAll(header, req.Header, true)
ctx := s.T().Context()
rw := httptest.NewRecorder()
rww := s.rwFactory().NewWriter(rw)
err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
s.Require().NoError(err)
return rw
}
// TestHandlerBasicRequest checks basic streaming request
func (s *HandlerTestSuite) TestHandlerBasicRequest() {
data := s.readTestFile("test1.png")
@@ -81,12 +129,7 @@ func (s *HandlerTestSuite) TestHandlerBasicRequest() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
@@ -114,12 +157,7 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
@@ -148,16 +186,12 @@ func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(httpheaders.IfNoneMatch, etag)
req.Header.Set(httpheaders.AcceptEncoding, "gzip")
req.Header.Set(httpheaders.Range, "bytes=*")
h := make(http.Header)
h.Set(httpheaders.IfNoneMatch, etag)
h.Set(httpheaders.AcceptEncoding, "gzip")
h.Set(httpheaders.Range, "bytes=*")
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
@@ -175,8 +209,6 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{
Filename: "custom_name",
ReturnAttachment: true,
@@ -184,8 +216,7 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
// Use a URL with a .png extension to help content disposition logic
imageURL := ts.URL + "/test.png"
err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(imageURL, nil, po)
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
@@ -342,25 +373,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
}))
defer ts.Close()
fc, err := fetcher.LoadConfigFromEnv(nil)
s.Require().NoError(err)
s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
s.rwConf().DefaultTTL = 4242
fetcher, err := fetcher.New(fc)
s.Require().NoError(err)
cfg := NewDefaultConfig()
hwc := headerwriter.NewDefaultConfig()
hwc.CacheControlPassthrough = tc.cacheControlPassthrough
hwc.DefaultTTL = 4242
hw, err := headerwriter.New(&hwc)
s.Require().NoError(err)
handler, err := New(&cfg, hw, fetcher)
s.Require().NoError(err)
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
if tc.timestampOffset != nil {
@@ -368,8 +383,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
po.Expires = &expires
}
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, nil, po)
res := rw.Result()
s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
@@ -400,12 +414,7 @@ func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
@@ -420,12 +429,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(404, res.StatusCode)
@@ -433,21 +437,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
fc, err := fetcher.LoadConfigFromEnv(nil)
s.Require().NoError(err)
fetcher, err := fetcher.New(fc)
s.Require().NoError(err)
cfg := NewDefaultConfig()
cfg.CookiePassthrough = true
hwc := headerwriter.NewDefaultConfig()
hw, err := headerwriter.New(&hwc)
s.Require().NoError(err)
handler, err := New(&cfg, hw, fetcher)
s.Require().NoError(err)
s.config().CookiePassthrough = true
data := s.readTestFile("test1.png")
@@ -464,13 +454,10 @@ func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
}))
defer ts.Close()
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
h := make(http.Header)
h.Set(httpheaders.Cookie, "test_cookie=test_value")
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)
@@ -488,29 +475,9 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
defer ts.Close()
for _, sc := range []bool{true, false} {
fc, err := fetcher.LoadConfigFromEnv(nil)
s.Require().NoError(err)
s.rwConf().SetCanonicalHeader = sc
fetcher, err := fetcher.New(fc)
s.Require().NoError(err)
cfg := NewDefaultConfig()
hwc := headerwriter.NewDefaultConfig()
hwc.SetCanonicalHeader = sc
hw, err := headerwriter.New(&hwc)
s.Require().NoError(err)
handler, err := New(&cfg, hw, fetcher)
s.Require().NoError(err)
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
po := &options.ProcessingOptions{}
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
s.Require().NoError(err)
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
res := rw.Result()
s.Require().Equal(200, res.StatusCode)

View File

@@ -1,62 +0,0 @@
package headerwriter
import (
"fmt"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/ensure"
)
// Config is the package-local configuration
type Config struct {
SetCanonicalHeader bool // Indicates whether to set the canonical header
DefaultTTL int // Default Cache-Control max-age= value for cached images
FallbackImageTTL int // TTL for images served as fallbacks
CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
EnableClientHints bool // Enable Vary header
SetVaryAccept bool // Whether to include Accept in Vary header
}
// NewDefaultConfig returns a new Config instance with default values.
func NewDefaultConfig() Config {
return Config{
SetCanonicalHeader: false,
DefaultTTL: 31536000,
FallbackImageTTL: 0,
CacheControlPassthrough: false,
EnableClientHints: false,
SetVaryAccept: false,
}
}
// LoadConfigFromEnv overrides configuration variables from environment
func LoadConfigFromEnv(c *Config) (*Config, error) {
c = ensure.Ensure(c, NewDefaultConfig)
c.SetCanonicalHeader = config.SetCanonicalHeader
c.DefaultTTL = config.TTL
c.FallbackImageTTL = config.FallbackImageTTL
c.CacheControlPassthrough = config.CacheControlPassthrough
c.EnableClientHints = config.EnableClientHints
c.SetVaryAccept = config.AutoWebp ||
config.EnforceWebp ||
config.AutoAvif ||
config.EnforceAvif ||
config.AutoJxl ||
config.EnforceJxl
return c, nil
}
// Validate checks config for errors
func (c *Config) Validate() error {
if c.DefaultTTL < 0 {
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
}
if c.FallbackImageTTL < 0 {
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
}
return nil
}

View File

@@ -1,214 +0,0 @@
// headerwriter is responsible for writing processing/stream response headers
package headerwriter
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/imgproxy/imgproxy/v3/httpheaders"
)
// Writer is a struct that creates header writer factories.
type Writer struct {
config *Config
varyValue string
}
// Request is a private struct that builds HTTP response headers for a specific request.
type Request struct {
writer *Writer
originHeaders http.Header // Original response headers
result http.Header // Headers to be written to the response
maxAge int // Current max age for Cache-Control header
}
// New creates a new header writer factory with the provided config.
func New(config *Config) (*Writer, error) {
if err := config.Validate(); err != nil {
return nil, err
}
vary := make([]string, 0)
if config.SetVaryAccept {
vary = append(vary, "Accept")
}
if config.EnableClientHints {
vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
}
varyValue := strings.Join(vary, ", ")
return &Writer{
config: config,
varyValue: varyValue,
}, nil
}
// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL.
func (w *Writer) NewRequest() *Request {
return &Request{
writer: w,
result: make(http.Header),
maxAge: -1,
originHeaders: make(http.Header),
}
}
// SetOriginHeaders sets the origin headers for the request.
func (r *Request) SetOriginHeaders(h http.Header) {
r.originHeaders = h
}
// SetIsFallbackImage sets the Fallback-Image header to
// indicate that the fallback image was used.
func (r *Request) SetIsFallbackImage() {
// We set maxAge to FallbackImageTTL if it's explicitly passed
if r.writer.config.FallbackImageTTL < 0 {
return
}
// However, we should not overwrite existing value if set (or greater than ours)
if r.maxAge < 0 || r.maxAge > r.writer.config.FallbackImageTTL {
r.maxAge = r.writer.config.FallbackImageTTL
}
}
// SetExpires sets the TTL from time
func (r *Request) SetExpires(expires *time.Time) {
if expires == nil {
return
}
// Convert current maxAge to time
currentMaxAgeTime := time.Now().Add(time.Duration(r.maxAge) * time.Second)
// If maxAge outlives expires or was not set, we'll use expires as maxAge.
if r.maxAge < 0 || expires.Before(currentMaxAgeTime) {
r.maxAge = min(r.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
}
}
// SetVary sets the Vary header
func (r *Request) SetVary() {
if len(r.writer.varyValue) > 0 {
r.result.Set(httpheaders.Vary, r.writer.varyValue)
}
}
// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
func (r *Request) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
value := httpheaders.ContentDispositionValue(
originURL,
filename,
ext,
contentType,
returnAttachment,
)
if value != "" {
r.result.Set(httpheaders.ContentDisposition, value)
}
}
// Passthrough copies specified headers from the original response headers to the response headers.
func (r *Request) Passthrough(only ...string) {
httpheaders.Copy(r.originHeaders, r.result, only)
}
// CopyFrom copies specified headers from the headers object. Please note that
// all the past operations may overwrite those values.
func (r *Request) CopyFrom(headers http.Header, only []string) {
httpheaders.Copy(headers, r.result, only)
}
// SetContentLength sets the Content-Length header
func (r *Request) SetContentLength(contentLength int) {
if contentLength < 0 {
return
}
r.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
}
// SetContentType sets the Content-Type header
func (r *Request) SetContentType(mime string) {
r.result.Set(httpheaders.ContentType, mime)
}
// writeCanonical sets the Link header with the canonical URL.
// It is mandatory for any response if enabled in the configuration.
func (r *Request) SetCanonical(url string) {
if !r.writer.config.SetCanonicalHeader {
return
}
if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
r.result.Set(httpheaders.Link, value)
}
}
// setCacheControl sets the Cache-Control header with the specified value.
func (r *Request) setCacheControl(value int) bool {
if value <= 0 {
return false
}
r.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
return true
}
// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
func (r *Request) setCacheControlNoCache() {
r.result.Set(httpheaders.CacheControl, "no-cache")
}
// setCacheControlPassthrough sets the Cache-Control header from the request
// if passthrough is enabled in the configuration.
func (r *Request) setCacheControlPassthrough() bool {
if !r.writer.config.CacheControlPassthrough || r.maxAge > 0 {
return false
}
if val := r.originHeaders.Get(httpheaders.CacheControl); val != "" {
r.result.Set(httpheaders.CacheControl, val)
return true
}
if val := r.originHeaders.Get(httpheaders.Expires); val != "" {
if t, err := time.Parse(http.TimeFormat, val); err == nil {
maxAge := max(0, int(time.Until(t).Seconds()))
return r.setCacheControl(maxAge)
}
}
return false
}
// setCSP sets the Content-Security-Policy header to prevent script execution.
func (r *Request) setCSP() {
r.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
}
// Write writes the headers to the response writer. It does not overwrite
// target headers, which were set outside the header writer.
func (r *Request) Write(rw http.ResponseWriter) {
// Then, let's try to set Cache-Control using priority order
switch {
case r.setCacheControl(r.maxAge): // First, try set explicit
case r.setCacheControlPassthrough(): // Try to pick up from request headers
case r.setCacheControl(r.writer.config.DefaultTTL): // Fallback to default value
default:
r.setCacheControlNoCache() // By default we use no-cache
}
r.setCSP()
// Copy all headers to the response without overwriting existing ones
httpheaders.CopyAll(r.result, rw.Header(), false)
}

View File

@@ -11,7 +11,6 @@ import (
landinghandler "github.com/imgproxy/imgproxy/v3/handlers/landing"
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
"github.com/imgproxy/imgproxy/v3/headerwriter"
"github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/memory"
"github.com/imgproxy/imgproxy/v3/monitoring/prometheus"
@@ -34,7 +33,6 @@ type ImgproxyHandlers struct {
// Imgproxy holds all the components needed for imgproxy to function
type Imgproxy struct {
headerWriter *headerwriter.Writer
semaphores *semaphores.Semaphores
fallbackImage auximageprovider.Provider
watermarkImage auximageprovider.Provider
@@ -46,11 +44,6 @@ type Imgproxy struct {
// New creates a new imgproxy instance
func New(ctx context.Context, config *Config) (*Imgproxy, error) {
headerWriter, err := headerwriter.New(&config.HeaderWriter)
if err != nil {
return nil, err
}
fetcher, err := fetcher.New(&config.Fetcher)
if err != nil {
return nil, err
@@ -74,7 +67,6 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
}
imgproxy := &Imgproxy{
headerWriter: headerWriter,
semaphores: semaphores,
fallbackImage: fallbackImage,
watermarkImage: watermarkImage,
@@ -86,7 +78,7 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
imgproxy.handlers.Health = healthhandler.New()
imgproxy.handlers.Landing = landinghandler.New()
imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, headerWriter, fetcher)
imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, fetcher)
if err != nil {
return nil, err
}
@@ -180,10 +172,6 @@ func (i *Imgproxy) startMemoryTicker(ctx context.Context) {
}
}
func (i *Imgproxy) HeaderWriter() *headerwriter.Writer {
return i.headerWriter
}
func (i *Imgproxy) Semaphores() *semaphores.Semaphores {
return i.semaphores
}

View File

@@ -238,7 +238,7 @@ func (s *ProcessingHandlerTestSuite) TestErrorSavingToSVG() {
}
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
s.Config().HeaderWriter.CacheControlPassthrough = true
s.Config().Server.ResponseWriter.CacheControlPassthrough = true
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(httpheaders.CacheControl, "max-age=1234, public")

View File

@@ -8,6 +8,7 @@ import (
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/ensure"
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
)
// Config represents HTTP server config
@@ -18,7 +19,6 @@ type Config struct {
PathPrefix string // Path prefix for the server
MaxClients int // Maximum number of concurrent clients
ReadRequestTimeout time.Duration // Timeout for reading requests
WriteResponseTimeout time.Duration // Timeout for writing responses
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
GracefulTimeout time.Duration // Timeout for graceful shutdown
CORSAllowOrigin string // CORS allowed origin
@@ -27,6 +27,8 @@ type Config struct {
SocketReusePort bool // Enable SO_REUSEPORT socket option
HealthCheckPath string // Health check path from config
ResponseWriter responsewriter.Config // Response writer config
// TODO: We are not sure where to put it yet
FreeMemoryInterval time.Duration // Interval for freeing memory
LogMemStats bool // Log memory stats
@@ -41,7 +43,6 @@ func NewDefaultConfig() Config {
MaxClients: 2048,
ReadRequestTimeout: 10 * time.Second,
KeepAliveTimeout: 10 * time.Second,
WriteResponseTimeout: 10 * time.Second,
GracefulTimeout: 20 * time.Second,
CORSAllowOrigin: "",
Secret: "",
@@ -50,6 +51,8 @@ func NewDefaultConfig() Config {
HealthCheckPath: "",
FreeMemoryInterval: 10 * time.Second,
LogMemStats: false,
ResponseWriter: responsewriter.NewDefaultConfig(),
}
}
@@ -72,6 +75,10 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
c.FreeMemoryInterval = time.Duration(config.FreeMemoryInterval) * time.Second
c.LogMemStats = len(os.Getenv("IMGPROXY_LOG_MEM_STATS")) > 0
if _, err := responsewriter.LoadConfigFromEnv(&c.ResponseWriter); err != nil {
return nil, err
}
return c, nil
}
@@ -89,10 +96,6 @@ func (c *Config) Validate() error {
return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout)
}
if c.WriteResponseTimeout <= 0 {
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
}
if c.KeepAliveTimeout < 0 {
return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout)
}

View File

@@ -23,10 +23,13 @@ func (r *Router) WithMonitoring(h RouteHandler) RouteHandler {
return h
}
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
ctx, cancel, rw := monitoring.StartRequest(req.Context(), rw, req)
return func(reqID string, rw ResponseWriter, req *http.Request) error {
ctx, cancel, newRw := monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req)
defer cancel()
// Replace rw.ResponseWriter with new one returned from monitoring
rw.SetHTTPResponseWriter(newRw)
return h(reqID, rw, req.WithContext(ctx))
}
}
@@ -37,7 +40,7 @@ func (r *Router) WithCORS(h RouteHandler) RouteHandler {
return h
}
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
return func(reqID string, rw ResponseWriter, req *http.Request) error {
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
@@ -53,7 +56,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
return func(reqID string, rw ResponseWriter, req *http.Request) error {
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
return h(reqID, rw, req)
} else {
@@ -64,7 +67,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
// WithPanic recovers panic and converts it to normal error
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) {
defer func() {
// try to recover from panic
rerr := recover()
@@ -94,7 +97,7 @@ func (r *Router) WithPanic(h RouteHandler) RouteHandler {
// WithReportError handles error reporting.
// It should be placed after `WithMonitoring`, but before `WithPanic`.
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
return func(reqID string, rw ResponseWriter, req *http.Request) error {
// Open the error context
ctx := errorreport.StartRequest(req)
req = req.WithContext(ctx)

View File

@@ -0,0 +1,87 @@
package responsewriter
import (
"fmt"
"strings"
"time"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/ensure"
)
// Config holds configuration for response writer
type Config struct {
SetCanonicalHeader bool // Indicates whether to set the canonical header
DefaultTTL int // Default Cache-Control max-age= value for cached images
FallbackImageTTL int // TTL for images served as fallbacks
CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
VaryValue string // Value for Vary header
WriteResponseTimeout time.Duration // Timeout for response write operations
}
// NewDefaultConfig returns a new Config instance with default values.
func NewDefaultConfig() Config {
return Config{
SetCanonicalHeader: false,
DefaultTTL: 31536000,
FallbackImageTTL: 0,
CacheControlPassthrough: false,
VaryValue: "",
WriteResponseTimeout: 10 * time.Second,
}
}
// LoadConfigFromEnv overrides configuration variables from environment
func LoadConfigFromEnv(c *Config) (*Config, error) {
c = ensure.Ensure(c, NewDefaultConfig)
c.SetCanonicalHeader = config.SetCanonicalHeader
c.DefaultTTL = config.TTL
c.FallbackImageTTL = config.FallbackImageTTL
c.CacheControlPassthrough = config.CacheControlPassthrough
c.WriteResponseTimeout = time.Duration(config.WriteResponseTimeout) * time.Second
vary := make([]string, 0)
if c.envEnableFormatDetection() {
vary = append(vary, "Accept")
}
if c.envEnableClientHints() {
vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
}
c.VaryValue = strings.Join(vary, ", ")
return c, nil
}
func (c *Config) envEnableFormatDetection() bool {
return config.AutoWebp ||
config.EnforceWebp ||
config.AutoAvif ||
config.EnforceAvif ||
config.AutoJxl ||
config.EnforceJxl
}
func (c *Config) envEnableClientHints() bool {
return config.EnableClientHints
}
// Validate checks config for errors
func (c *Config) Validate() error {
if c.DefaultTTL < 0 {
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
}
if c.FallbackImageTTL < 0 {
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
}
if c.WriteResponseTimeout <= 0 {
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
}
return nil
}

View File

@@ -0,0 +1,114 @@
package responsewriter
import (
"fmt"
"io"
"os"
"testing"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
)
type ResponseWriterConfigSuite struct {
suite.Suite
}
func (s *ResponseWriterConfigSuite) SetupSuite() {
logrus.SetOutput(io.Discard)
}
func (s *ResponseWriterConfigSuite) TearDownSuite() {
logrus.SetOutput(os.Stdout)
}
func (s *ResponseWriterConfigSuite) TestLoadingVaryValueFromEnv() {
defaultEnv := map[string]string{
"IMGPROXY_AUTO_WEBP": "",
"IMGPROXY_ENFORCE_WEBP": "",
"IMGPROXY_AUTO_AVIF": "",
"IMGPROXY_ENFORCE_AVIF": "",
"IMGPROXY_AUTO_JXL": "",
"IMGPROXY_ENFORCE_JXL": "",
"IMGPROXY_ENABLE_CLIENT_HINTS": "",
}
testCases := []struct {
name string
env map[string]string
expected string
}{
{
name: "AutoWebP",
env: map[string]string{"IMGPROXY_AUTO_WEBP": "true"},
expected: "Accept",
},
{
name: "EnforceWebP",
env: map[string]string{"IMGPROXY_ENFORCE_WEBP": "true"},
expected: "Accept",
},
{
name: "AutoAVIF",
env: map[string]string{"IMGPROXY_AUTO_AVIF": "true"},
expected: "Accept",
},
{
name: "EnforceAVIF",
env: map[string]string{"IMGPROXY_ENFORCE_AVIF": "true"},
expected: "Accept",
},
{
name: "AutoJXL",
env: map[string]string{"IMGPROXY_AUTO_JXL": "true"},
expected: "Accept",
},
{
name: "EnforceJXL",
env: map[string]string{"IMGPROXY_ENFORCE_JXL": "true"},
expected: "Accept",
},
{
name: "EnableClientHints",
env: map[string]string{"IMGPROXY_ENABLE_CLIENT_HINTS": "true"},
expected: "Sec-CH-DPR, DPR, Sec-CH-Width, Width",
},
{
name: "Combined",
env: map[string]string{
"IMGPROXY_AUTO_WEBP": "true",
"IMGPROXY_ENABLE_CLIENT_HINTS": "true",
},
expected: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
},
}
for _, tc := range testCases {
s.Run(fmt.Sprintf("%v", tc.env), func() {
// Set default environment variables
for key, value := range defaultEnv {
s.T().Setenv(key, value)
}
// Set environment variables
for key, value := range tc.env {
s.T().Setenv(key, value)
}
// TODO: Remove when we removed global config
config.Reset()
config.Configure()
// Load config
cfg, err := LoadConfigFromEnv(nil)
// Assert expected values
s.Require().NoError(err)
s.Require().Equal(tc.expected, cfg.VaryValue)
})
}
}
func TestResponseWriterConfig(t *testing.T) {
suite.Run(t, new(ResponseWriterConfigSuite))
}

View File

@@ -0,0 +1,30 @@
package responsewriter
import "net/http"
// Factory is a struct that creates response writers.
type Factory struct {
config *Config
}
func NewFactory(config *Config) (*Factory, error) {
if err := config.Validate(); err != nil {
return nil, err
}
return &Factory{config}, nil
}
// NewWriter wraps [http.ResponseWriter] into [Writer].
func (f *Factory) NewWriter(rw http.ResponseWriter) *Writer {
w := &Writer{
config: f.config,
result: make(http.Header),
originHeaders: make(http.Header),
maxAge: -1,
}
w.SetHTTPResponseWriter(rw)
return w
}

View File

@@ -0,0 +1,226 @@
package responsewriter
import (
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/imgproxy/imgproxy/v3/httpheaders"
)
// Just aliases for [http.ResponseWriter] and [http.ResponseController].
// We need them to make them private in [Writer] so they can't be accessed directly.
type httpResponseWriter = http.ResponseWriter
type httpResponseController = *http.ResponseController
// Writer is an implementation of [http.ResponseWriter] with additional
// functionality for managing response headers.
type Writer struct {
httpResponseWriter
httpResponseController
config *Config // Configuration for the writer
originHeaders http.Header // Original response headers
result http.Header // Headers to be written to the response
maxAge int // Current max age for Cache-Control header
beforeWriteOnce sync.Once
}
// HTTPResponseWriter returns the underlying http.ResponseWriter.
func (w *Writer) HTTPResponseWriter() http.ResponseWriter {
return w.httpResponseWriter
}
// SetHTTPResponseWriter replaces the underlying http.ResponseWriter.
func (w *Writer) SetHTTPResponseWriter(rw http.ResponseWriter) {
w.httpResponseWriter = rw
w.httpResponseController = http.NewResponseController(rw)
}
// SetOriginHeaders sets the origin headers for the request.
func (w *Writer) SetOriginHeaders(h http.Header) {
w.originHeaders = h
}
// SetIsFallbackImage sets the Fallback-Image header to
// indicate that the fallback image was used.
func (w *Writer) SetIsFallbackImage() {
// We set maxAge to FallbackImageTTL if it's explicitly passed
if w.config.FallbackImageTTL < 0 {
return
}
// However, we should not overwrite existing value if set (or greater than ours)
if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL {
w.maxAge = w.config.FallbackImageTTL
}
}
// SetExpires sets the TTL from time
func (w *Writer) SetExpires(expires *time.Time) {
if expires == nil {
return
}
// Convert current maxAge to time
currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second)
// If maxAge outlives expires or was not set, we'll use expires as maxAge.
if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
}
}
// SetVary sets the Vary header
func (w *Writer) SetVary() {
if val := w.config.VaryValue; len(val) > 0 {
w.result.Set(httpheaders.Vary, val)
}
}
// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
func (w *Writer) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
value := httpheaders.ContentDispositionValue(
originURL,
filename,
ext,
contentType,
returnAttachment,
)
if value != "" {
w.result.Set(httpheaders.ContentDisposition, value)
}
}
// Passthrough copies specified headers from the original response headers to the response headers.
func (w *Writer) Passthrough(only ...string) {
httpheaders.Copy(w.originHeaders, w.result, only)
}
// CopyFrom copies specified headers from the headers object. Please note that
// all the past operations may overwrite those values.
func (w *Writer) CopyFrom(headers http.Header, only []string) {
httpheaders.Copy(headers, w.result, only)
}
// SetContentLength sets the Content-Length header
func (w *Writer) SetContentLength(contentLength int) {
if contentLength < 0 {
return
}
w.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
}
// SetContentType sets the Content-Type header
func (w *Writer) SetContentType(mime string) {
w.result.Set(httpheaders.ContentType, mime)
}
// writeCanonical sets the Link header with the canonical URL.
// It is mandatory for any response if enabled in the configuration.
func (w *Writer) SetCanonical(url string) {
if !w.config.SetCanonicalHeader {
return
}
if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
w.result.Set(httpheaders.Link, value)
}
}
// setCacheControl sets the Cache-Control header with the specified value.
func (w *Writer) setCacheControl(value int) bool {
if value <= 0 {
return false
}
w.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
return true
}
// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
func (w *Writer) setCacheControlNoCache() {
w.result.Set(httpheaders.CacheControl, "no-cache")
}
// setCacheControlPassthrough sets the Cache-Control header from the request
// if passthrough is enabled in the configuration.
func (w *Writer) setCacheControlPassthrough() bool {
if !w.config.CacheControlPassthrough || w.maxAge > 0 {
return false
}
if val := w.originHeaders.Get(httpheaders.CacheControl); val != "" {
w.result.Set(httpheaders.CacheControl, val)
return true
}
if val := w.originHeaders.Get(httpheaders.Expires); val != "" {
if t, err := time.Parse(http.TimeFormat, val); err == nil {
maxAge := max(0, int(time.Until(t).Seconds()))
return w.setCacheControl(maxAge)
}
}
return false
}
// setCSP sets the Content-Security-Policy header to prevent script execution.
func (w *Writer) setCSP() {
w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
}
// flushHeaders writes the headers to the response writer. It does not overwrite
// target headers, which were set outside the header writer.
func (w *Writer) flushHeaders() {
// Then, let's try to set Cache-Control using priority order
switch {
case w.setCacheControl(w.maxAge): // First, try set explicit
case w.setCacheControlPassthrough(): // Try to pick up from request headers
case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value
default:
w.setCacheControlNoCache() // By default we use no-cache
}
w.setCSP()
// Copy all headers to the response without overwriting existing ones
httpheaders.CopyAll(w.result, w.Header(), false)
}
// beforeWrite is called before [WriteHeader] and [Write]
func (w *Writer) beforeWrite() {
w.beforeWriteOnce.Do(func() {
// We're going to start writing response.
// Set write deadline.
w.SetWriteDeadline(time.Now().Add(w.config.WriteResponseTimeout))
// Flush headers before we write anything
w.flushHeaders()
})
}
// WriteHeader writes the HTTP response header.
//
// It ensures that all headers are flushed before writing the status code.
func (w *Writer) WriteHeader(statusCode int) {
w.beforeWrite()
w.httpResponseWriter.WriteHeader(statusCode)
}
// Write writes the HTTP response body.
//
// It ensures that all headers are flushed before writing the body.
func (w *Writer) Write(b []byte) (int, error) {
w.beforeWrite()
return w.httpResponseWriter.Write(b)
}

View File

@@ -1,4 +1,4 @@
package headerwriter
package responsewriter
import (
"fmt"
@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/suite"
)
type HeaderWriterSuite struct {
type ResponseWriterSuite struct {
suite.Suite
}
@@ -22,16 +22,18 @@ type writerTestCase struct {
req http.Header
res http.Header
config Config
fn func(*Request)
fn func(*Writer)
}
func (s *HeaderWriterSuite) TestHeaderCases() {
func (s *ResponseWriterSuite) TestHeaderCases() {
expires := time.Date(2030, 8, 1, 0, 0, 0, 0, time.UTC)
expiresSeconds := strconv.Itoa(int(time.Until(expires).Seconds()))
shortExpires := time.Now().Add(10 * time.Second)
shortExpiresSeconds := strconv.Itoa(int(time.Until(shortExpires).Seconds()))
writeResponseTimeout := 10 * time.Second
tt := []writerTestCase{
{
name: "MinimalHeaders",
@@ -44,8 +46,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
SetCanonicalHeader: false,
DefaultTTL: 0,
CacheControlPassthrough: false,
EnableClientHints: false,
SetVaryAccept: false,
WriteResponseTimeout: writeResponseTimeout,
},
},
{
@@ -60,6 +61,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
CacheControlPassthrough: true,
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
},
{
@@ -74,6 +76,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
CacheControlPassthrough: true,
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
},
{
@@ -88,6 +91,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
CacheControlPassthrough: true,
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
},
{
@@ -101,8 +105,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
SetCanonicalHeader: true,
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetCanonical("https://example.com/image.jpg")
},
},
@@ -116,6 +121,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
SetCanonicalHeader: true,
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
},
{
@@ -128,8 +134,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
SetCanonicalHeader: false,
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetCanonical("https://example.com/image.jpg")
},
},
@@ -143,8 +150,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
DefaultTTL: 3600,
FallbackImageTTL: 1,
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetIsFallbackImage()
},
},
@@ -157,8 +165,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
},
config: Config{
DefaultTTL: math.MaxInt32,
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetExpires(&expires)
},
},
@@ -172,8 +181,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{
DefaultTTL: math.MaxInt32,
FallbackImageTTL: 600,
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetIsFallbackImage()
w.SetExpires(&shortExpires)
},
@@ -187,10 +197,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{
EnableClientHints: true,
SetVaryAccept: true,
VaryValue: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetVary()
},
},
@@ -204,8 +214,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.CacheControl: []string{"no-cache"},
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{},
fn: func(w *Request) {
config: Config{
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Writer) {
w.Passthrough("X-Test")
},
},
@@ -217,8 +229,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.CacheControl: []string{"no-cache"},
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{},
fn: func(w *Request) {
config: Config{
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Writer) {
h := http.Header{}
h.Set("X-From", "baz")
w.CopyFrom(h, []string{"X-From"})
@@ -232,8 +246,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.CacheControl: []string{"no-cache"},
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{},
fn: func(w *Request) {
config: Config{
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Writer) {
w.SetContentLength(123)
},
},
@@ -245,8 +261,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.CacheControl: []string{"no-cache"},
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{},
fn: func(w *Request) {
config: Config{
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Writer) {
w.SetContentType("image/png")
},
},
@@ -259,57 +277,29 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
},
config: Config{
DefaultTTL: 3600,
WriteResponseTimeout: writeResponseTimeout,
},
fn: func(w *Request) {
fn: func(w *Writer) {
w.SetExpires(nil)
},
},
{
name: "WriteVaryAcceptOnly",
req: http.Header{},
res: http.Header{
httpheaders.Vary: []string{"Accept"},
httpheaders.CacheControl: []string{"no-cache"},
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{
SetVaryAccept: true,
},
fn: func(w *Request) {
w.SetVary()
},
},
{
name: "WriteVaryClientHintsOnly",
req: http.Header{},
res: http.Header{
httpheaders.Vary: []string{"Sec-CH-DPR, DPR, Sec-CH-Width, Width"},
httpheaders.CacheControl: []string{"no-cache"},
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
},
config: Config{
EnableClientHints: true,
},
fn: func(w *Request) {
w.SetVary()
},
},
}
for _, tc := range tt {
s.Run(tc.name, func() {
factory, err := New(&tc.config)
factory, err := NewFactory(&tc.config)
s.Require().NoError(err)
writer := factory.NewRequest()
r := httptest.NewRecorder()
writer := factory.NewWriter(r)
writer.SetOriginHeaders(tc.req)
if tc.fn != nil {
tc.fn(writer)
}
r := httptest.NewRecorder()
writer.Write(r)
writer.WriteHeader(http.StatusOK)
s.Require().Equal(tc.res, r.Header())
})
@@ -317,5 +307,5 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
}
func TestHeaderWriter(t *testing.T) {
suite.Run(t, new(HeaderWriterSuite))
suite.Run(t, new(ResponseWriterSuite))
}

View File

@@ -11,6 +11,7 @@ import (
nanoid "github.com/matoous/go-nanoid/v2"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
)
const (
@@ -23,8 +24,10 @@ var (
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
)
type ResponseWriter = *responsewriter.Writer
// RouteHandler is a function that handles HTTP requests.
type RouteHandler func(string, http.ResponseWriter, *http.Request) error
type RouteHandler func(string, ResponseWriter, *http.Request) error
// Middleware is a function that wraps a RouteHandler with additional functionality.
type Middleware func(next RouteHandler) RouteHandler
@@ -40,6 +43,9 @@ type route struct {
// Router is responsible for routing HTTP requests
type Router struct {
// Response writers factory
rwFactory *responsewriter.Factory
// config represents the server configuration
config *Config
@@ -53,7 +59,15 @@ func NewRouter(config *Config) (*Router, error) {
return nil, err
}
return &Router{config: config}, nil
rwf, err := responsewriter.NewFactory(&config.ResponseWriter)
if err != nil {
return nil, err
}
return &Router{
rwFactory: rwf,
config: config,
}, nil
}
// add adds an abitary route to the router
@@ -114,8 +128,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req, timeoutCancel := startRequestTimer(req)
defer timeoutCancel()
// Create the response writer which times out on write
rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
// Create the [ResponseWriter]
rww := r.rwFactory.NewWriter(rw)
// Get/create request ID
reqID := r.getRequestID(req)
@@ -123,8 +137,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Replace request IP from headers
r.replaceRemoteAddr(req)
rw.Header().Set(httpheaders.Server, defaultServerName)
rw.Header().Set(httpheaders.XRequestID, reqID)
rww.Header().Set(httpheaders.Server, defaultServerName)
rww.Header().Set(httpheaders.XRequestID, reqID)
for _, rr := range r.routes {
if !rr.isMatch(req) {
@@ -138,18 +152,18 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
LogRequest(reqID, req)
}
rr.handler(reqID, rw, req)
rr.handler(reqID, rww, req)
return
}
// Means that we have not found matching route
LogRequest(reqID, req)
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
r.NotFoundHandler(reqID, rw, req)
r.NotFoundHandler(reqID, rww, req)
}
// NotFoundHandler is default 404 handler
func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
func (r *Router) NotFoundHandler(reqID string, rw ResponseWriter, req *http.Request) error {
rw.Header().Set(httpheaders.ContentType, "text/plain")
rw.WriteHeader(http.StatusNotFound)
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
@@ -158,7 +172,7 @@ func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http
}
// OkHandler is a default 200 OK handler
func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
func (r *Router) OkHandler(reqID string, rw ResponseWriter, req *http.Request) error {
rw.Header().Set(httpheaders.ContentType, "text/plain")
rw.WriteHeader(http.StatusOK)
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy

View File

@@ -30,7 +30,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
var capturedMethod string
var capturedPath string
getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
getHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
capturedMethod = req.Method
capturedPath = req.URL.Path
rw.WriteHeader(200)
@@ -38,7 +38,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
return nil
}
optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
optionsHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
capturedMethod = req.Method
capturedPath = req.URL.Path
rw.WriteHeader(200)
@@ -46,7 +46,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
return nil
}
headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
headHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
capturedMethod = req.Method
capturedPath = req.URL.Path
rw.WriteHeader(200)
@@ -114,20 +114,20 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
var order []string
middleware1 := func(next RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
return func(reqID string, rw ResponseWriter, req *http.Request) error {
order = append(order, "middleware1")
return next(reqID, rw, req)
}
}
middleware2 := func(next RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
return func(reqID string, rw ResponseWriter, req *http.Request) error {
order = append(order, "middleware2")
return next(reqID, rw, req)
}
}
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
order = append(order, "handler")
rw.WriteHeader(200)
return nil
@@ -146,7 +146,7 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
// TestServeHTTP tests ServeHTTP method
func (s *RouterTestSuite) TestServeHTTP() {
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
rw.Header().Set("Custom-Header", "test-value")
rw.WriteHeader(200)
rw.Write([]byte("success"))
@@ -169,7 +169,7 @@ func (s *RouterTestSuite) TestServeHTTP() {
// TestRequestID checks request ID generation and validation
func (s *RouterTestSuite) TestRequestID() {
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
rw.WriteHeader(200)
return nil
}
@@ -209,7 +209,7 @@ func (s *RouterTestSuite) TestRequestID() {
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
rw.WriteHeader(200)
return nil
}
@@ -229,7 +229,7 @@ func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
// Test IP address handling
func (s *RouterTestSuite) TestReplaceIP() {
var capturedRemoteAddr string
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
capturedRemoteAddr = req.RemoteAddr
rw.WriteHeader(200)
return nil
@@ -298,7 +298,7 @@ func (s *RouterTestSuite) TestReplaceIP() {
// TestRouteOrder checks exact/non-exact insertion order
func (s *RouterTestSuite) TestRouteOrder() {
h := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
h := func(reqID string, rw ResponseWriter, req *http.Request) error {
return nil
}

View File

@@ -29,10 +29,14 @@ func (s *ServerTestSuite) SetupTest() {
s.blankRouter = r
}
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
return nil
}
func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
return s.blankRouter.rwFactory.NewWriter(rw)
}
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
ctx, cancel := context.WithCancel(s.T().Context())
@@ -121,7 +125,7 @@ func (s *ServerTestSuite) TestWithCORS() {
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
wrappedHandler("test-req-id", rw, req)
wrappedHandler("test-req-id", s.wrapRW(rw), req)
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
@@ -170,7 +174,7 @@ func (s *ServerTestSuite) TestWithSecret() {
}
rw := httptest.NewRecorder()
err = wrappedHandler("test-req-id", rw, req)
err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
if tt.expectError {
s.Require().Error(err)
@@ -182,7 +186,7 @@ func (s *ServerTestSuite) TestWithSecret() {
}
func (s *ServerTestSuite) TestIntoSuccess() {
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
rw.WriteHeader(http.StatusOK)
return nil
}
@@ -192,14 +196,14 @@ func (s *ServerTestSuite) TestIntoSuccess() {
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
wrappedHandler("test-req-id", rw, req)
wrappedHandler("test-req-id", s.wrapRW(rw), req)
s.Equal(http.StatusOK, rw.Code)
}
func (s *ServerTestSuite) TestIntoWithError() {
testError := errors.New("test error")
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
return testError
}
@@ -208,7 +212,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
wrappedHandler("test-req-id", rw, req)
wrappedHandler("test-req-id", s.wrapRW(rw), req)
s.Equal(http.StatusInternalServerError, rw.Code)
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
@@ -216,7 +220,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
func (s *ServerTestSuite) TestIntoPanicWithError() {
testError := errors.New("panic error")
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
panic(testError)
}
@@ -226,7 +230,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
rw := httptest.NewRecorder()
s.NotPanics(func() {
err := wrappedHandler("test-req-id", rw, req)
err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
s.Require().Error(err, "panic error")
})
@@ -234,7 +238,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
}
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
panic(http.ErrAbortHandler)
}
@@ -245,12 +249,12 @@ func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
// Should re-panic with ErrAbortHandler
s.Panics(func() {
wrappedHandler("test-req-id", rw, req)
wrappedHandler("test-req-id", s.wrapRW(rw), req)
})
}
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
panic("string panic")
}
@@ -261,7 +265,7 @@ func (s *ServerTestSuite) TestIntoPanicWithNonError() {
// Should re-panic with non-error panics
s.NotPanics(func() {
err := wrappedHandler("test-req-id", rw, req)
err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
s.Require().Error(err, "string panic")
})
}

View File

@@ -1,47 +0,0 @@
package server
import (
"net/http"
"time"
)
// timeoutResponse manages response writer with timeout. It has
// timeout on all write methods.
type timeoutResponse struct {
http.ResponseWriter
controller *http.ResponseController
timeout time.Duration
}
// newTimeoutResponse creates a new timeoutResponse
func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
return &timeoutResponse{
ResponseWriter: rw,
controller: http.NewResponseController(rw),
timeout: timeout,
}
}
// Write implements http.ResponseWriter.Write
func (rw *timeoutResponse) Write(b []byte) (int, error) {
var (
n int
err error
)
rw.withWriteDeadline(func() {
n, err = rw.ResponseWriter.Write(b)
})
return n, err
}
// withWriteDeadline executes a Write* function with a deadline
func (rw *timeoutResponse) withWriteDeadline(f func()) {
deadline := time.Now().Add(rw.timeout)
// Set write deadline
rw.controller.SetWriteDeadline(deadline)
// Reset write deadline after method has finished
defer rw.controller.SetWriteDeadline(time.Time{})
f()
}