From 1f6d007948d7ac2b76cac7f0679bf5cc793936d5 Mon Sep 17 00:00:00 2001 From: DarthSim Date: Thu, 11 Sep 2025 12:04:04 +0300 Subject: [PATCH] Rebuild headerwriter to server.ResponseWriter --- config.go | 7 - handlers/health/handler.go | 2 +- handlers/health/handler_test.go | 9 +- handlers/landing/handler.go | 3 +- handlers/processing/handler.go | 6 +- handlers/processing/request.go | 10 +- handlers/processing/request_methods.go | 32 ++- handlers/stream/handler.go | 37 ++- handlers/stream/handler_test.go | 203 +++++++--------- headerwriter/config.go | 62 ----- headerwriter/writer.go | 214 ----------------- imgproxy.go | 14 +- integration/processing_handler_test.go | 2 +- server/config.go | 15 +- server/middlewares.go | 15 +- server/responsewriter/config.go | 87 +++++++ server/responsewriter/config_test.go | 114 +++++++++ server/responsewriter/factory.go | 30 +++ server/responsewriter/writer.go | 226 ++++++++++++++++++ .../responsewriter}/writer_test.go | 130 +++++----- server/router.go | 34 ++- server/router_test.go | 22 +- server/server_test.go | 30 ++- server/timeout_response.go | 47 ---- 24 files changed, 720 insertions(+), 631 deletions(-) delete mode 100644 headerwriter/config.go delete mode 100644 headerwriter/writer.go create mode 100644 server/responsewriter/config.go create mode 100644 server/responsewriter/config_test.go create mode 100644 server/responsewriter/factory.go create mode 100644 server/responsewriter/writer.go rename {headerwriter => server/responsewriter}/writer_test.go (77%) delete mode 100644 server/timeout_response.go diff --git a/config.go b/config.go index 84b09a65..c826f333 100644 --- a/config.go +++ b/config.go @@ -6,7 +6,6 @@ import ( "github.com/imgproxy/imgproxy/v3/fetcher" processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing" streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream" - "github.com/imgproxy/imgproxy/v3/headerwriter" "github.com/imgproxy/imgproxy/v3/semaphores" "github.com/imgproxy/imgproxy/v3/server" ) @@ -19,7 +18,6 @@ type HandlerConfigs struct { // Config represents an instance configuration type Config struct { - HeaderWriter headerwriter.Config Semaphores semaphores.Config FallbackImage auximageprovider.StaticConfig WatermarkImage auximageprovider.StaticConfig @@ -31,7 +29,6 @@ type Config struct { // NewDefaultConfig creates a new default configuration func NewDefaultConfig() Config { return Config{ - HeaderWriter: headerwriter.NewDefaultConfig(), Semaphores: semaphores.NewDefaultConfig(), FallbackImage: auximageprovider.NewDefaultStaticConfig(), WatermarkImage: auximageprovider.NewDefaultStaticConfig(), @@ -62,10 +59,6 @@ func LoadConfigFromEnv(c *Config) (*Config, error) { return nil, err } - if _, err = headerwriter.LoadConfigFromEnv(&c.HeaderWriter); err != nil { - return nil, err - } - if _, err = semaphores.LoadConfigFromEnv(&c.Semaphores); err != nil { return nil, err } diff --git a/handlers/health/handler.go b/handlers/health/handler.go index 54c789ca..679849b6 100644 --- a/handlers/health/handler.go +++ b/handlers/health/handler.go @@ -22,7 +22,7 @@ func New() *Handler { // Execute handles the health request func (h *Handler) Execute( reqID string, - rw http.ResponseWriter, + rw server.ResponseWriter, req *http.Request, ) error { var ( diff --git a/handlers/health/handler_test.go b/handlers/health/handler_test.go index 85af23e9..266b9ace 100644 --- a/handlers/health/handler_test.go +++ b/handlers/health/handler_test.go @@ -6,11 +6,18 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/server/responsewriter" ) func TestHealthHandler(t *testing.T) { + // Create responsewriter.Factory + rwConf := responsewriter.NewDefaultConfig() + rwf, err := responsewriter.NewFactory(&rwConf) + require.NoError(t, err) + // Create a ResponseRecorder to record the response rr := httptest.NewRecorder() @@ -18,7 +25,7 @@ func TestHealthHandler(t *testing.T) { h := New() // Call the handler function directly (no need for actual HTTP request) - h.Execute("test-req-id", rr, nil) + h.Execute("test-req-id", rwf.NewWriter(rr), nil) // Check that we get a valid response (either 200 or 500 depending on vips state) assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError) diff --git a/handlers/landing/handler.go b/handlers/landing/handler.go index 04da94ae..a6adae2a 100644 --- a/handlers/landing/handler.go +++ b/handlers/landing/handler.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/server" ) //go:embed body.html @@ -21,7 +22,7 @@ func New() *Handler { // Execute handles the landing request func (h *Handler) Execute( reqID string, - rw http.ResponseWriter, + rw server.ResponseWriter, req *http.Request, ) error { rw.Header().Set(httpheaders.ContentType, "text/html") diff --git a/handlers/processing/handler.go b/handlers/processing/handler.go index 0980c407..c902872f 100644 --- a/handlers/processing/handler.go +++ b/handlers/processing/handler.go @@ -9,7 +9,6 @@ import ( "github.com/imgproxy/imgproxy/v3/errorreport" "github.com/imgproxy/imgproxy/v3/handlers" "github.com/imgproxy/imgproxy/v3/handlers/stream" - "github.com/imgproxy/imgproxy/v3/headerwriter" "github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/monitoring" @@ -17,11 +16,11 @@ import ( "github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/security" "github.com/imgproxy/imgproxy/v3/semaphores" + "github.com/imgproxy/imgproxy/v3/server" ) // HandlerContext provides access to shared handler dependencies type HandlerContext interface { - HeaderWriter() *headerwriter.Writer Semaphores() *semaphores.Semaphores FallbackImage() auximageprovider.Provider WatermarkImage() auximageprovider.Provider @@ -56,7 +55,7 @@ func New( // Execute handles the image processing request func (h *Handler) Execute( reqID string, - rw http.ResponseWriter, + rw server.ResponseWriter, req *http.Request, ) error { // Increment the number of requests in progress @@ -86,7 +85,6 @@ func (h *Handler) Execute( po: po, imageURL: imageURL, monitoringMeta: mm, - hwr: h.HeaderWriter().NewRequest(), } return hReq.execute(ctx) diff --git a/handlers/processing/request.go b/handlers/processing/request.go index bacd0777..f1b0dec1 100644 --- a/handlers/processing/request.go +++ b/handlers/processing/request.go @@ -7,7 +7,6 @@ import ( "github.com/imgproxy/imgproxy/v3/fetcher" "github.com/imgproxy/imgproxy/v3/handlers" - "github.com/imgproxy/imgproxy/v3/headerwriter" "github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/monitoring" @@ -23,12 +22,11 @@ type request struct { reqID string req *http.Request - rw http.ResponseWriter + rw server.ResponseWriter config *Config po *options.ProcessingOptions imageURL string monitoringMeta monitoring.Meta - hwr *headerwriter.Request } // execute handles the actual processing logic @@ -84,13 +82,13 @@ func (r *request) execute(ctx context.Context) error { var nmErr fetcher.NotModifiedError if errors.As(err, &nmErr) { - r.hwr.SetOriginHeaders(nmErr.Headers()) + r.rw.SetOriginHeaders(nmErr.Headers()) return r.respondWithNotModified() } // Prepare to write image response headers - r.hwr.SetOriginHeaders(originHeaders) + r.rw.SetOriginHeaders(originHeaders) // If error is not related to NotModified, respond with fallback image and replace image data if err != nil { @@ -123,7 +121,7 @@ func (r *request) execute(ctx context.Context) error { return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryProcessing)) } - // Write debug headers. It seems unlogical to move they to headerwriter since they're + // Write debug headers. It seems unlogical to move they to responsewriter since they're // not used anywhere else. err = r.writeDebugHeaders(result, originData) if err != nil { diff --git a/handlers/processing/request_methods.go b/handlers/processing/request_methods.go index dc0437f9..d013569f 100644 --- a/handlers/processing/request_methods.go +++ b/handlers/processing/request_methods.go @@ -123,8 +123,8 @@ func (r *request) handleDownloadError( headers.Del(httpheaders.Expires) headers.Del(httpheaders.LastModified) - r.hwr.SetOriginHeaders(headers) - r.hwr.SetIsFallbackImage() + r.rw.SetOriginHeaders(headers) + r.rw.SetIsFallbackImage() return data, statusCode, nil } @@ -186,19 +186,17 @@ func (r *request) writeDebugHeaders(result *processing.Result, originData imaged // respondWithNotModified writes not-modified response func (r *request) respondWithNotModified() error { - r.hwr.SetExpires(r.po.Expires) - r.hwr.SetVary() + r.rw.SetExpires(r.po.Expires) + r.rw.SetVary() if r.config.LastModifiedEnabled { - r.hwr.Passthrough(httpheaders.LastModified) + r.rw.Passthrough(httpheaders.LastModified) } if r.config.ETagEnabled { - r.hwr.Passthrough(httpheaders.Etag) + r.rw.Passthrough(httpheaders.Etag) } - r.hwr.Write(r.rw) - r.rw.WriteHeader(http.StatusNotModified) server.LogResponse( @@ -221,29 +219,27 @@ func (r *request) respondWithImage(statusCode int, resultData imagedata.ImageDat return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryImageDataSize)) } - r.hwr.SetContentType(resultData.Format().Mime()) - r.hwr.SetContentLength(resultSize) - r.hwr.SetContentDisposition( + r.rw.SetContentType(resultData.Format().Mime()) + r.rw.SetContentLength(resultSize) + r.rw.SetContentDisposition( r.imageURL, r.po.Filename, resultData.Format().Ext(), "", r.po.ReturnAttachment, ) - r.hwr.SetExpires(r.po.Expires) - r.hwr.SetVary() - r.hwr.SetCanonical(r.imageURL) + r.rw.SetExpires(r.po.Expires) + r.rw.SetVary() + r.rw.SetCanonical(r.imageURL) if r.config.LastModifiedEnabled { - r.hwr.Passthrough(httpheaders.LastModified) + r.rw.Passthrough(httpheaders.LastModified) } if r.config.ETagEnabled { - r.hwr.Passthrough(httpheaders.Etag) + r.rw.Passthrough(httpheaders.Etag) } - r.hwr.Write(r.rw) - r.rw.WriteHeader(statusCode) _, err = io.Copy(r.rw, resultData.Reader()) diff --git a/handlers/stream/handler.go b/handlers/stream/handler.go index 594994ee..3b932940 100644 --- a/handlers/stream/handler.go +++ b/handlers/stream/handler.go @@ -6,16 +6,16 @@ import ( "net/http" "sync" + log "github.com/sirupsen/logrus" + "github.com/imgproxy/imgproxy/v3/cookies" "github.com/imgproxy/imgproxy/v3/fetcher" - "github.com/imgproxy/imgproxy/v3/headerwriter" "github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/monitoring" "github.com/imgproxy/imgproxy/v3/monitoring/stats" "github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/server" - log "github.com/sirupsen/logrus" ) const ( @@ -35,9 +35,8 @@ var ( // Handler handles image passthrough requests, allowing images to be streamed directly type Handler struct { - config *Config // Configuration for the streamer - fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching - hw *headerwriter.Writer // Configured HeaderWriter instance + config *Config // Configuration for the streamer + fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching } // request holds the parameters and state for a single streaming request @@ -47,12 +46,11 @@ type request struct { imageURL string reqID string po *options.ProcessingOptions - rw http.ResponseWriter - hw *headerwriter.Request + rw server.ResponseWriter } // New creates new handler object -func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Handler, error) { +func New(config *Config, fetcher *fetcher.Fetcher) (*Handler, error) { if err := config.Validate(); err != nil { return nil, err } @@ -60,7 +58,6 @@ func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Ha return &Handler{ fetcher: fetcher, config: config, - hw: hw, }, nil } @@ -71,7 +68,7 @@ func (s *Handler) Execute( imageURL string, reqID string, po *options.ProcessingOptions, - rw http.ResponseWriter, + rw server.ResponseWriter, ) error { stream := &request{ handler: s, @@ -80,7 +77,6 @@ func (s *Handler) Execute( reqID: reqID, po: po, rw: rw, - hw: s.hw.NewRequest(), } return stream.execute(ctx) @@ -118,17 +114,14 @@ func (s *request) execute(ctx context.Context) error { } // Output streaming response headers - s.hw.SetOriginHeaders(res.Header) - s.hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was - s.hw.SetContentLength(int(res.ContentLength)) - s.hw.SetCanonical(s.imageURL) - s.hw.SetExpires(s.po.Expires) + s.rw.SetOriginHeaders(res.Header) + s.rw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was + s.rw.SetContentLength(int(res.ContentLength)) + s.rw.SetCanonical(s.imageURL) + s.rw.SetExpires(s.po.Expires) // Set the Content-Disposition header - s.setContentDisposition(r.URL().Path, res, s.hw) - - // Write headers from writer - s.hw.Write(s.rw) + s.setContentDisposition(r.URL().Path, res) // Copy the status code from the original response s.rw.WriteHeader(res.StatusCode) @@ -158,7 +151,7 @@ func (s *request) getImageRequestHeaders() http.Header { } // setContentDisposition writes the headers to the response writer -func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response, hw *headerwriter.Request) { +func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response) { // Try to set correct Content-Disposition file name and extension if serverResponse.StatusCode < 200 || serverResponse.StatusCode >= 300 { return @@ -166,7 +159,7 @@ func (s *request) setContentDisposition(imagePath string, serverResponse *http.R ct := serverResponse.Header.Get(httpheaders.ContentType) - hw.SetContentDisposition( + s.rw.SetContentDisposition( imagePath, s.po.Filename, "", diff --git a/handlers/stream/handler_test.go b/handlers/stream/handler_test.go index 1a9418f5..66c0d6c3 100644 --- a/handlers/stream/handler_test.go +++ b/handlers/stream/handler_test.go @@ -1,7 +1,6 @@ package stream import ( - "context" "fmt" "io" "net/http" @@ -17,9 +16,10 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/fetcher" - "github.com/imgproxy/imgproxy/v3/headerwriter" "github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/options" + "github.com/imgproxy/imgproxy/v3/server/responsewriter" + "github.com/imgproxy/imgproxy/v3/testutil" ) const ( @@ -27,14 +27,54 @@ const ( ) type HandlerTestSuite struct { - suite.Suite - handler *Handler + testutil.LazySuite + + rwConf testutil.LazyObj[*responsewriter.Config] + rwFactory testutil.LazyObj[*responsewriter.Factory] + + config testutil.LazyObj[*Config] + handler testutil.LazyObj[*Handler] } func (s *HandlerTestSuite) SetupSuite() { config.Reset() config.AllowLoopbackSourceAddresses = true + s.rwConf, _ = testutil.NewLazySuiteObj( + s, + func() (*responsewriter.Config, error) { + c := responsewriter.NewDefaultConfig() + return &c, nil + }, + ) + + s.rwFactory, _ = testutil.NewLazySuiteObj( + s, + func() (*responsewriter.Factory, error) { + return responsewriter.NewFactory(s.rwConf()) + }, + ) + + s.config, _ = testutil.NewLazySuiteObj( + s, + func() (*Config, error) { + c := NewDefaultConfig() + return &c, nil + }, + ) + + s.handler, _ = testutil.NewLazySuiteObj( + s, + func() (*Handler, error) { + fc := fetcher.NewDefaultConfig() + + fetcher, err := fetcher.New(&fc) + s.Require().NoError(err) + + return New(s.config(), fetcher) + }, + ) + // Silence logs during tests logrus.SetOutput(io.Discard) } @@ -47,21 +87,11 @@ func (s *HandlerTestSuite) TearDownSuite() { func (s *HandlerTestSuite) SetupTest() { config.Reset() config.AllowLoopbackSourceAddresses = true +} - fc := fetcher.NewDefaultConfig() - - fetcher, err := fetcher.New(&fc) - s.Require().NoError(err) - - cfg := NewDefaultConfig() - - hwc := headerwriter.NewDefaultConfig() - hw, err := headerwriter.New(&hwc) - s.Require().NoError(err) - - h, err := New(&cfg, hw, fetcher) - s.Require().NoError(err) - s.handler = h +func (s *HandlerTestSuite) SetupSubTest() { + // We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest + s.ResetLazyObjects() } func (s *HandlerTestSuite) readTestFile(name string) []byte { @@ -70,6 +100,24 @@ func (s *HandlerTestSuite) readTestFile(name string) []byte { return data } +func (s *HandlerTestSuite) execute( + imageURL string, + header http.Header, + po *options.ProcessingOptions, +) *httptest.ResponseRecorder { + req := httptest.NewRequest("GET", "/", nil) + httpheaders.CopyAll(header, req.Header, true) + + ctx := s.T().Context() + rw := httptest.NewRecorder() + rww := s.rwFactory().NewWriter(rw) + + err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww) + s.Require().NoError(err) + + return rw +} + // TestHandlerBasicRequest checks basic streaming request func (s *HandlerTestSuite) TestHandlerBasicRequest() { data := s.readTestFile("test1.png") @@ -81,12 +129,7 @@ func (s *HandlerTestSuite) TestHandlerBasicRequest() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} - - err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(200, res.StatusCode) @@ -114,12 +157,7 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} - - err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(200, res.StatusCode) @@ -148,16 +186,12 @@ func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - req.Header.Set(httpheaders.IfNoneMatch, etag) - req.Header.Set(httpheaders.AcceptEncoding, "gzip") - req.Header.Set(httpheaders.Range, "bytes=*") + h := make(http.Header) + h.Set(httpheaders.IfNoneMatch, etag) + h.Set(httpheaders.AcceptEncoding, "gzip") + h.Set(httpheaders.Range, "bytes=*") - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} - - err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, h, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(200, res.StatusCode) @@ -175,8 +209,6 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() po := &options.ProcessingOptions{ Filename: "custom_name", ReturnAttachment: true, @@ -184,8 +216,7 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() { // Use a URL with a .png extension to help content disposition logic imageURL := ts.URL + "/test.png" - err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(imageURL, nil, po) res := rw.Result() s.Require().Equal(200, res.StatusCode) @@ -342,25 +373,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { })) defer ts.Close() - fc, err := fetcher.LoadConfigFromEnv(nil) - s.Require().NoError(err) + s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough + s.rwConf().DefaultTTL = 4242 - fetcher, err := fetcher.New(fc) - s.Require().NoError(err) - - cfg := NewDefaultConfig() - hwc := headerwriter.NewDefaultConfig() - hwc.CacheControlPassthrough = tc.cacheControlPassthrough - hwc.DefaultTTL = 4242 - - hw, err := headerwriter.New(&hwc) - s.Require().NoError(err) - - handler, err := New(&cfg, hw, fetcher) - s.Require().NoError(err) - - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() po := &options.ProcessingOptions{} if tc.timestampOffset != nil { @@ -368,8 +383,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { po.Expires = &expires } - err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, nil, po) res := rw.Result() s.Require().Equal(tc.expectedStatusCode, res.StatusCode) @@ -400,12 +414,7 @@ func (s *HandlerTestSuite) TestHandlerSecurityHeaders() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} - - err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(200, res.StatusCode) @@ -420,12 +429,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} - - err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(404, res.StatusCode) @@ -433,21 +437,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() { // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() { - fc, err := fetcher.LoadConfigFromEnv(nil) - s.Require().NoError(err) - - fetcher, err := fetcher.New(fc) - s.Require().NoError(err) - - cfg := NewDefaultConfig() - cfg.CookiePassthrough = true - - hwc := headerwriter.NewDefaultConfig() - hw, err := headerwriter.New(&hwc) - s.Require().NoError(err) - - handler, err := New(&cfg, hw, fetcher) - s.Require().NoError(err) + s.config().CookiePassthrough = true data := s.readTestFile("test1.png") @@ -464,13 +454,10 @@ func (s *HandlerTestSuite) TestHandlerCookiePassthrough() { })) defer ts.Close() - req := httptest.NewRequest("GET", "/", nil) - req.Header.Set(httpheaders.Cookie, "test_cookie=test_value") - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} + h := make(http.Header) + h.Set(httpheaders.Cookie, "test_cookie=test_value") - err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, h, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(200, res.StatusCode) @@ -488,29 +475,9 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() { defer ts.Close() for _, sc := range []bool{true, false} { - fc, err := fetcher.LoadConfigFromEnv(nil) - s.Require().NoError(err) + s.rwConf().SetCanonicalHeader = sc - fetcher, err := fetcher.New(fc) - s.Require().NoError(err) - - cfg := NewDefaultConfig() - hwc := headerwriter.NewDefaultConfig() - - hwc.SetCanonicalHeader = sc - - hw, err := headerwriter.New(&hwc) - s.Require().NoError(err) - - handler, err := New(&cfg, hw, fetcher) - s.Require().NoError(err) - - req := httptest.NewRequest("GET", "/", nil) - rw := httptest.NewRecorder() - po := &options.ProcessingOptions{} - - err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw) - s.Require().NoError(err) + rw := s.execute(ts.URL, nil, &options.ProcessingOptions{}) res := rw.Result() s.Require().Equal(200, res.StatusCode) diff --git a/headerwriter/config.go b/headerwriter/config.go deleted file mode 100644 index 8b2ef5c0..00000000 --- a/headerwriter/config.go +++ /dev/null @@ -1,62 +0,0 @@ -package headerwriter - -import ( - "fmt" - - "github.com/imgproxy/imgproxy/v3/config" - "github.com/imgproxy/imgproxy/v3/ensure" -) - -// Config is the package-local configuration -type Config struct { - SetCanonicalHeader bool // Indicates whether to set the canonical header - DefaultTTL int // Default Cache-Control max-age= value for cached images - FallbackImageTTL int // TTL for images served as fallbacks - CacheControlPassthrough bool // Passthrough the Cache-Control from the original response - EnableClientHints bool // Enable Vary header - SetVaryAccept bool // Whether to include Accept in Vary header -} - -// NewDefaultConfig returns a new Config instance with default values. -func NewDefaultConfig() Config { - return Config{ - SetCanonicalHeader: false, - DefaultTTL: 31536000, - FallbackImageTTL: 0, - CacheControlPassthrough: false, - EnableClientHints: false, - SetVaryAccept: false, - } -} - -// LoadConfigFromEnv overrides configuration variables from environment -func LoadConfigFromEnv(c *Config) (*Config, error) { - c = ensure.Ensure(c, NewDefaultConfig) - - c.SetCanonicalHeader = config.SetCanonicalHeader - c.DefaultTTL = config.TTL - c.FallbackImageTTL = config.FallbackImageTTL - c.CacheControlPassthrough = config.CacheControlPassthrough - c.EnableClientHints = config.EnableClientHints - c.SetVaryAccept = config.AutoWebp || - config.EnforceWebp || - config.AutoAvif || - config.EnforceAvif || - config.AutoJxl || - config.EnforceJxl - - return c, nil -} - -// Validate checks config for errors -func (c *Config) Validate() error { - if c.DefaultTTL < 0 { - return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL) - } - - if c.FallbackImageTTL < 0 { - return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL) - } - - return nil -} diff --git a/headerwriter/writer.go b/headerwriter/writer.go deleted file mode 100644 index 62a27d6f..00000000 --- a/headerwriter/writer.go +++ /dev/null @@ -1,214 +0,0 @@ -// headerwriter is responsible for writing processing/stream response headers -package headerwriter - -import ( - "fmt" - "net/http" - "strconv" - "strings" - "time" - - "github.com/imgproxy/imgproxy/v3/httpheaders" -) - -// Writer is a struct that creates header writer factories. -type Writer struct { - config *Config - varyValue string -} - -// Request is a private struct that builds HTTP response headers for a specific request. -type Request struct { - writer *Writer - originHeaders http.Header // Original response headers - result http.Header // Headers to be written to the response - maxAge int // Current max age for Cache-Control header -} - -// New creates a new header writer factory with the provided config. -func New(config *Config) (*Writer, error) { - if err := config.Validate(); err != nil { - return nil, err - } - - vary := make([]string, 0) - - if config.SetVaryAccept { - vary = append(vary, "Accept") - } - - if config.EnableClientHints { - vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width") - } - - varyValue := strings.Join(vary, ", ") - - return &Writer{ - config: config, - varyValue: varyValue, - }, nil -} - -// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL. -func (w *Writer) NewRequest() *Request { - return &Request{ - writer: w, - result: make(http.Header), - maxAge: -1, - originHeaders: make(http.Header), - } -} - -// SetOriginHeaders sets the origin headers for the request. -func (r *Request) SetOriginHeaders(h http.Header) { - r.originHeaders = h -} - -// SetIsFallbackImage sets the Fallback-Image header to -// indicate that the fallback image was used. -func (r *Request) SetIsFallbackImage() { - // We set maxAge to FallbackImageTTL if it's explicitly passed - if r.writer.config.FallbackImageTTL < 0 { - return - } - - // However, we should not overwrite existing value if set (or greater than ours) - if r.maxAge < 0 || r.maxAge > r.writer.config.FallbackImageTTL { - r.maxAge = r.writer.config.FallbackImageTTL - } -} - -// SetExpires sets the TTL from time -func (r *Request) SetExpires(expires *time.Time) { - if expires == nil { - return - } - - // Convert current maxAge to time - currentMaxAgeTime := time.Now().Add(time.Duration(r.maxAge) * time.Second) - - // If maxAge outlives expires or was not set, we'll use expires as maxAge. - if r.maxAge < 0 || expires.Before(currentMaxAgeTime) { - r.maxAge = min(r.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds()))) - } -} - -// SetVary sets the Vary header -func (r *Request) SetVary() { - if len(r.writer.varyValue) > 0 { - r.result.Set(httpheaders.Vary, r.writer.varyValue) - } -} - -// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue -func (r *Request) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) { - value := httpheaders.ContentDispositionValue( - originURL, - filename, - ext, - contentType, - returnAttachment, - ) - - if value != "" { - r.result.Set(httpheaders.ContentDisposition, value) - } -} - -// Passthrough copies specified headers from the original response headers to the response headers. -func (r *Request) Passthrough(only ...string) { - httpheaders.Copy(r.originHeaders, r.result, only) -} - -// CopyFrom copies specified headers from the headers object. Please note that -// all the past operations may overwrite those values. -func (r *Request) CopyFrom(headers http.Header, only []string) { - httpheaders.Copy(headers, r.result, only) -} - -// SetContentLength sets the Content-Length header -func (r *Request) SetContentLength(contentLength int) { - if contentLength < 0 { - return - } - - r.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength)) -} - -// SetContentType sets the Content-Type header -func (r *Request) SetContentType(mime string) { - r.result.Set(httpheaders.ContentType, mime) -} - -// writeCanonical sets the Link header with the canonical URL. -// It is mandatory for any response if enabled in the configuration. -func (r *Request) SetCanonical(url string) { - if !r.writer.config.SetCanonicalHeader { - return - } - - if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") { - value := fmt.Sprintf(`<%s>; rel="canonical"`, url) - r.result.Set(httpheaders.Link, value) - } -} - -// setCacheControl sets the Cache-Control header with the specified value. -func (r *Request) setCacheControl(value int) bool { - if value <= 0 { - return false - } - - r.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value)) - return true -} - -// setCacheControlNoCache sets the Cache-Control header to no-cache (default). -func (r *Request) setCacheControlNoCache() { - r.result.Set(httpheaders.CacheControl, "no-cache") -} - -// setCacheControlPassthrough sets the Cache-Control header from the request -// if passthrough is enabled in the configuration. -func (r *Request) setCacheControlPassthrough() bool { - if !r.writer.config.CacheControlPassthrough || r.maxAge > 0 { - return false - } - - if val := r.originHeaders.Get(httpheaders.CacheControl); val != "" { - r.result.Set(httpheaders.CacheControl, val) - return true - } - - if val := r.originHeaders.Get(httpheaders.Expires); val != "" { - if t, err := time.Parse(http.TimeFormat, val); err == nil { - maxAge := max(0, int(time.Until(t).Seconds())) - return r.setCacheControl(maxAge) - } - } - - return false -} - -// setCSP sets the Content-Security-Policy header to prevent script execution. -func (r *Request) setCSP() { - r.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'") -} - -// Write writes the headers to the response writer. It does not overwrite -// target headers, which were set outside the header writer. -func (r *Request) Write(rw http.ResponseWriter) { - // Then, let's try to set Cache-Control using priority order - switch { - case r.setCacheControl(r.maxAge): // First, try set explicit - case r.setCacheControlPassthrough(): // Try to pick up from request headers - case r.setCacheControl(r.writer.config.DefaultTTL): // Fallback to default value - default: - r.setCacheControlNoCache() // By default we use no-cache - } - - r.setCSP() - - // Copy all headers to the response without overwriting existing ones - httpheaders.CopyAll(r.result, rw.Header(), false) -} diff --git a/imgproxy.go b/imgproxy.go index 36d703ca..262e9432 100644 --- a/imgproxy.go +++ b/imgproxy.go @@ -11,7 +11,6 @@ import ( landinghandler "github.com/imgproxy/imgproxy/v3/handlers/landing" processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing" streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream" - "github.com/imgproxy/imgproxy/v3/headerwriter" "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/memory" "github.com/imgproxy/imgproxy/v3/monitoring/prometheus" @@ -34,7 +33,6 @@ type ImgproxyHandlers struct { // Imgproxy holds all the components needed for imgproxy to function type Imgproxy struct { - headerWriter *headerwriter.Writer semaphores *semaphores.Semaphores fallbackImage auximageprovider.Provider watermarkImage auximageprovider.Provider @@ -46,11 +44,6 @@ type Imgproxy struct { // New creates a new imgproxy instance func New(ctx context.Context, config *Config) (*Imgproxy, error) { - headerWriter, err := headerwriter.New(&config.HeaderWriter) - if err != nil { - return nil, err - } - fetcher, err := fetcher.New(&config.Fetcher) if err != nil { return nil, err @@ -74,7 +67,6 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) { } imgproxy := &Imgproxy{ - headerWriter: headerWriter, semaphores: semaphores, fallbackImage: fallbackImage, watermarkImage: watermarkImage, @@ -86,7 +78,7 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) { imgproxy.handlers.Health = healthhandler.New() imgproxy.handlers.Landing = landinghandler.New() - imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, headerWriter, fetcher) + imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, fetcher) if err != nil { return nil, err } @@ -180,10 +172,6 @@ func (i *Imgproxy) startMemoryTicker(ctx context.Context) { } } -func (i *Imgproxy) HeaderWriter() *headerwriter.Writer { - return i.headerWriter -} - func (i *Imgproxy) Semaphores() *semaphores.Semaphores { return i.semaphores } diff --git a/integration/processing_handler_test.go b/integration/processing_handler_test.go index d84a459c..56d60c5b 100644 --- a/integration/processing_handler_test.go +++ b/integration/processing_handler_test.go @@ -238,7 +238,7 @@ func (s *ProcessingHandlerTestSuite) TestErrorSavingToSVG() { } func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() { - s.Config().HeaderWriter.CacheControlPassthrough = true + s.Config().Server.ResponseWriter.CacheControlPassthrough = true ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set(httpheaders.CacheControl, "max-age=1234, public") diff --git a/server/config.go b/server/config.go index 58a46c03..f528a055 100644 --- a/server/config.go +++ b/server/config.go @@ -8,6 +8,7 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/ensure" + "github.com/imgproxy/imgproxy/v3/server/responsewriter" ) // Config represents HTTP server config @@ -18,7 +19,6 @@ type Config struct { PathPrefix string // Path prefix for the server MaxClients int // Maximum number of concurrent clients ReadRequestTimeout time.Duration // Timeout for reading requests - WriteResponseTimeout time.Duration // Timeout for writing responses KeepAliveTimeout time.Duration // Timeout for keep-alive connections GracefulTimeout time.Duration // Timeout for graceful shutdown CORSAllowOrigin string // CORS allowed origin @@ -27,6 +27,8 @@ type Config struct { SocketReusePort bool // Enable SO_REUSEPORT socket option HealthCheckPath string // Health check path from config + ResponseWriter responsewriter.Config // Response writer config + // TODO: We are not sure where to put it yet FreeMemoryInterval time.Duration // Interval for freeing memory LogMemStats bool // Log memory stats @@ -41,7 +43,6 @@ func NewDefaultConfig() Config { MaxClients: 2048, ReadRequestTimeout: 10 * time.Second, KeepAliveTimeout: 10 * time.Second, - WriteResponseTimeout: 10 * time.Second, GracefulTimeout: 20 * time.Second, CORSAllowOrigin: "", Secret: "", @@ -50,6 +51,8 @@ func NewDefaultConfig() Config { HealthCheckPath: "", FreeMemoryInterval: 10 * time.Second, LogMemStats: false, + + ResponseWriter: responsewriter.NewDefaultConfig(), } } @@ -72,6 +75,10 @@ func LoadConfigFromEnv(c *Config) (*Config, error) { c.FreeMemoryInterval = time.Duration(config.FreeMemoryInterval) * time.Second c.LogMemStats = len(os.Getenv("IMGPROXY_LOG_MEM_STATS")) > 0 + if _, err := responsewriter.LoadConfigFromEnv(&c.ResponseWriter); err != nil { + return nil, err + } + return c, nil } @@ -89,10 +96,6 @@ func (c *Config) Validate() error { return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout) } - if c.WriteResponseTimeout <= 0 { - return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout) - } - if c.KeepAliveTimeout < 0 { return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout) } diff --git a/server/middlewares.go b/server/middlewares.go index c4b0fecb..590fc8c4 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -23,10 +23,13 @@ func (r *Router) WithMonitoring(h RouteHandler) RouteHandler { return h } - return func(reqID string, rw http.ResponseWriter, req *http.Request) error { - ctx, cancel, rw := monitoring.StartRequest(req.Context(), rw, req) + return func(reqID string, rw ResponseWriter, req *http.Request) error { + ctx, cancel, newRw := monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req) defer cancel() + // Replace rw.ResponseWriter with new one returned from monitoring + rw.SetHTTPResponseWriter(newRw) + return h(reqID, rw, req.WithContext(ctx)) } } @@ -37,7 +40,7 @@ func (r *Router) WithCORS(h RouteHandler) RouteHandler { return h } - return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + return func(reqID string, rw ResponseWriter, req *http.Request) error { rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin) rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS") @@ -53,7 +56,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler { authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret) - return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + return func(reqID string, rw ResponseWriter, req *http.Request) error { if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 { return h(reqID, rw, req) } else { @@ -64,7 +67,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler { // WithPanic recovers panic and converts it to normal error func (r *Router) WithPanic(h RouteHandler) RouteHandler { - return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) { + return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) { defer func() { // try to recover from panic rerr := recover() @@ -94,7 +97,7 @@ func (r *Router) WithPanic(h RouteHandler) RouteHandler { // WithReportError handles error reporting. // It should be placed after `WithMonitoring`, but before `WithPanic`. func (r *Router) WithReportError(h RouteHandler) RouteHandler { - return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + return func(reqID string, rw ResponseWriter, req *http.Request) error { // Open the error context ctx := errorreport.StartRequest(req) req = req.WithContext(ctx) diff --git a/server/responsewriter/config.go b/server/responsewriter/config.go new file mode 100644 index 00000000..fbcdf32c --- /dev/null +++ b/server/responsewriter/config.go @@ -0,0 +1,87 @@ +package responsewriter + +import ( + "fmt" + "strings" + "time" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/ensure" +) + +// Config holds configuration for response writer +type Config struct { + SetCanonicalHeader bool // Indicates whether to set the canonical header + DefaultTTL int // Default Cache-Control max-age= value for cached images + FallbackImageTTL int // TTL for images served as fallbacks + CacheControlPassthrough bool // Passthrough the Cache-Control from the original response + VaryValue string // Value for Vary header + WriteResponseTimeout time.Duration // Timeout for response write operations +} + +// NewDefaultConfig returns a new Config instance with default values. +func NewDefaultConfig() Config { + return Config{ + SetCanonicalHeader: false, + DefaultTTL: 31536000, + FallbackImageTTL: 0, + CacheControlPassthrough: false, + VaryValue: "", + WriteResponseTimeout: 10 * time.Second, + } +} + +// LoadConfigFromEnv overrides configuration variables from environment +func LoadConfigFromEnv(c *Config) (*Config, error) { + c = ensure.Ensure(c, NewDefaultConfig) + + c.SetCanonicalHeader = config.SetCanonicalHeader + c.DefaultTTL = config.TTL + c.FallbackImageTTL = config.FallbackImageTTL + c.CacheControlPassthrough = config.CacheControlPassthrough + c.WriteResponseTimeout = time.Duration(config.WriteResponseTimeout) * time.Second + + vary := make([]string, 0) + + if c.envEnableFormatDetection() { + vary = append(vary, "Accept") + } + + if c.envEnableClientHints() { + vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width") + } + + c.VaryValue = strings.Join(vary, ", ") + + return c, nil +} + +func (c *Config) envEnableFormatDetection() bool { + return config.AutoWebp || + config.EnforceWebp || + config.AutoAvif || + config.EnforceAvif || + config.AutoJxl || + config.EnforceJxl +} + +func (c *Config) envEnableClientHints() bool { + return config.EnableClientHints +} + +// Validate checks config for errors +func (c *Config) Validate() error { + if c.DefaultTTL < 0 { + return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL) + } + + if c.FallbackImageTTL < 0 { + return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL) + } + + if c.WriteResponseTimeout <= 0 { + return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout) + } + + return nil +} diff --git a/server/responsewriter/config_test.go b/server/responsewriter/config_test.go new file mode 100644 index 00000000..3953a2f8 --- /dev/null +++ b/server/responsewriter/config_test.go @@ -0,0 +1,114 @@ +package responsewriter + +import ( + "fmt" + "io" + "os" + "testing" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" +) + +type ResponseWriterConfigSuite struct { + suite.Suite +} + +func (s *ResponseWriterConfigSuite) SetupSuite() { + logrus.SetOutput(io.Discard) +} + +func (s *ResponseWriterConfigSuite) TearDownSuite() { + logrus.SetOutput(os.Stdout) +} + +func (s *ResponseWriterConfigSuite) TestLoadingVaryValueFromEnv() { + defaultEnv := map[string]string{ + "IMGPROXY_AUTO_WEBP": "", + "IMGPROXY_ENFORCE_WEBP": "", + "IMGPROXY_AUTO_AVIF": "", + "IMGPROXY_ENFORCE_AVIF": "", + "IMGPROXY_AUTO_JXL": "", + "IMGPROXY_ENFORCE_JXL": "", + "IMGPROXY_ENABLE_CLIENT_HINTS": "", + } + + testCases := []struct { + name string + env map[string]string + expected string + }{ + { + name: "AutoWebP", + env: map[string]string{"IMGPROXY_AUTO_WEBP": "true"}, + expected: "Accept", + }, + { + name: "EnforceWebP", + env: map[string]string{"IMGPROXY_ENFORCE_WEBP": "true"}, + expected: "Accept", + }, + { + name: "AutoAVIF", + env: map[string]string{"IMGPROXY_AUTO_AVIF": "true"}, + expected: "Accept", + }, + { + name: "EnforceAVIF", + env: map[string]string{"IMGPROXY_ENFORCE_AVIF": "true"}, + expected: "Accept", + }, + { + name: "AutoJXL", + env: map[string]string{"IMGPROXY_AUTO_JXL": "true"}, + expected: "Accept", + }, + { + name: "EnforceJXL", + env: map[string]string{"IMGPROXY_ENFORCE_JXL": "true"}, + expected: "Accept", + }, + { + name: "EnableClientHints", + env: map[string]string{"IMGPROXY_ENABLE_CLIENT_HINTS": "true"}, + expected: "Sec-CH-DPR, DPR, Sec-CH-Width, Width", + }, + { + name: "Combined", + env: map[string]string{ + "IMGPROXY_AUTO_WEBP": "true", + "IMGPROXY_ENABLE_CLIENT_HINTS": "true", + }, + expected: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width", + }, + } + + for _, tc := range testCases { + s.Run(fmt.Sprintf("%v", tc.env), func() { + // Set default environment variables + for key, value := range defaultEnv { + s.T().Setenv(key, value) + } + // Set environment variables + for key, value := range tc.env { + s.T().Setenv(key, value) + } + + // TODO: Remove when we removed global config + config.Reset() + config.Configure() + + // Load config + cfg, err := LoadConfigFromEnv(nil) + + // Assert expected values + s.Require().NoError(err) + s.Require().Equal(tc.expected, cfg.VaryValue) + }) + } +} + +func TestResponseWriterConfig(t *testing.T) { + suite.Run(t, new(ResponseWriterConfigSuite)) +} diff --git a/server/responsewriter/factory.go b/server/responsewriter/factory.go new file mode 100644 index 00000000..c3e2a7f2 --- /dev/null +++ b/server/responsewriter/factory.go @@ -0,0 +1,30 @@ +package responsewriter + +import "net/http" + +// Factory is a struct that creates response writers. +type Factory struct { + config *Config +} + +func NewFactory(config *Config) (*Factory, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + return &Factory{config}, nil +} + +// NewWriter wraps [http.ResponseWriter] into [Writer]. +func (f *Factory) NewWriter(rw http.ResponseWriter) *Writer { + w := &Writer{ + config: f.config, + result: make(http.Header), + originHeaders: make(http.Header), + maxAge: -1, + } + + w.SetHTTPResponseWriter(rw) + + return w +} diff --git a/server/responsewriter/writer.go b/server/responsewriter/writer.go new file mode 100644 index 00000000..89a6591b --- /dev/null +++ b/server/responsewriter/writer.go @@ -0,0 +1,226 @@ +package responsewriter + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/imgproxy/imgproxy/v3/httpheaders" +) + +// Just aliases for [http.ResponseWriter] and [http.ResponseController]. +// We need them to make them private in [Writer] so they can't be accessed directly. +type httpResponseWriter = http.ResponseWriter +type httpResponseController = *http.ResponseController + +// Writer is an implementation of [http.ResponseWriter] with additional +// functionality for managing response headers. +type Writer struct { + httpResponseWriter + httpResponseController + + config *Config // Configuration for the writer + originHeaders http.Header // Original response headers + result http.Header // Headers to be written to the response + maxAge int // Current max age for Cache-Control header + + beforeWriteOnce sync.Once +} + +// HTTPResponseWriter returns the underlying http.ResponseWriter. +func (w *Writer) HTTPResponseWriter() http.ResponseWriter { + return w.httpResponseWriter +} + +// SetHTTPResponseWriter replaces the underlying http.ResponseWriter. +func (w *Writer) SetHTTPResponseWriter(rw http.ResponseWriter) { + w.httpResponseWriter = rw + w.httpResponseController = http.NewResponseController(rw) +} + +// SetOriginHeaders sets the origin headers for the request. +func (w *Writer) SetOriginHeaders(h http.Header) { + w.originHeaders = h +} + +// SetIsFallbackImage sets the Fallback-Image header to +// indicate that the fallback image was used. +func (w *Writer) SetIsFallbackImage() { + // We set maxAge to FallbackImageTTL if it's explicitly passed + if w.config.FallbackImageTTL < 0 { + return + } + + // However, we should not overwrite existing value if set (or greater than ours) + if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL { + w.maxAge = w.config.FallbackImageTTL + } +} + +// SetExpires sets the TTL from time +func (w *Writer) SetExpires(expires *time.Time) { + if expires == nil { + return + } + + // Convert current maxAge to time + currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second) + + // If maxAge outlives expires or was not set, we'll use expires as maxAge. + if w.maxAge < 0 || expires.Before(currentMaxAgeTime) { + w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds()))) + } +} + +// SetVary sets the Vary header +func (w *Writer) SetVary() { + if val := w.config.VaryValue; len(val) > 0 { + w.result.Set(httpheaders.Vary, val) + } +} + +// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue +func (w *Writer) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) { + value := httpheaders.ContentDispositionValue( + originURL, + filename, + ext, + contentType, + returnAttachment, + ) + + if value != "" { + w.result.Set(httpheaders.ContentDisposition, value) + } +} + +// Passthrough copies specified headers from the original response headers to the response headers. +func (w *Writer) Passthrough(only ...string) { + httpheaders.Copy(w.originHeaders, w.result, only) +} + +// CopyFrom copies specified headers from the headers object. Please note that +// all the past operations may overwrite those values. +func (w *Writer) CopyFrom(headers http.Header, only []string) { + httpheaders.Copy(headers, w.result, only) +} + +// SetContentLength sets the Content-Length header +func (w *Writer) SetContentLength(contentLength int) { + if contentLength < 0 { + return + } + + w.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength)) +} + +// SetContentType sets the Content-Type header +func (w *Writer) SetContentType(mime string) { + w.result.Set(httpheaders.ContentType, mime) +} + +// writeCanonical sets the Link header with the canonical URL. +// It is mandatory for any response if enabled in the configuration. +func (w *Writer) SetCanonical(url string) { + if !w.config.SetCanonicalHeader { + return + } + + if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") { + value := fmt.Sprintf(`<%s>; rel="canonical"`, url) + w.result.Set(httpheaders.Link, value) + } +} + +// setCacheControl sets the Cache-Control header with the specified value. +func (w *Writer) setCacheControl(value int) bool { + if value <= 0 { + return false + } + + w.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value)) + return true +} + +// setCacheControlNoCache sets the Cache-Control header to no-cache (default). +func (w *Writer) setCacheControlNoCache() { + w.result.Set(httpheaders.CacheControl, "no-cache") +} + +// setCacheControlPassthrough sets the Cache-Control header from the request +// if passthrough is enabled in the configuration. +func (w *Writer) setCacheControlPassthrough() bool { + if !w.config.CacheControlPassthrough || w.maxAge > 0 { + return false + } + + if val := w.originHeaders.Get(httpheaders.CacheControl); val != "" { + w.result.Set(httpheaders.CacheControl, val) + return true + } + + if val := w.originHeaders.Get(httpheaders.Expires); val != "" { + if t, err := time.Parse(http.TimeFormat, val); err == nil { + maxAge := max(0, int(time.Until(t).Seconds())) + return w.setCacheControl(maxAge) + } + } + + return false +} + +// setCSP sets the Content-Security-Policy header to prevent script execution. +func (w *Writer) setCSP() { + w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'") +} + +// flushHeaders writes the headers to the response writer. It does not overwrite +// target headers, which were set outside the header writer. +func (w *Writer) flushHeaders() { + // Then, let's try to set Cache-Control using priority order + switch { + case w.setCacheControl(w.maxAge): // First, try set explicit + case w.setCacheControlPassthrough(): // Try to pick up from request headers + case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value + default: + w.setCacheControlNoCache() // By default we use no-cache + } + + w.setCSP() + + // Copy all headers to the response without overwriting existing ones + httpheaders.CopyAll(w.result, w.Header(), false) +} + +// beforeWrite is called before [WriteHeader] and [Write] +func (w *Writer) beforeWrite() { + w.beforeWriteOnce.Do(func() { + // We're going to start writing response. + // Set write deadline. + w.SetWriteDeadline(time.Now().Add(w.config.WriteResponseTimeout)) + + // Flush headers before we write anything + w.flushHeaders() + }) +} + +// WriteHeader writes the HTTP response header. +// +// It ensures that all headers are flushed before writing the status code. +func (w *Writer) WriteHeader(statusCode int) { + w.beforeWrite() + + w.httpResponseWriter.WriteHeader(statusCode) +} + +// Write writes the HTTP response body. +// +// It ensures that all headers are flushed before writing the body. +func (w *Writer) Write(b []byte) (int, error) { + w.beforeWrite() + + return w.httpResponseWriter.Write(b) +} diff --git a/headerwriter/writer_test.go b/server/responsewriter/writer_test.go similarity index 77% rename from headerwriter/writer_test.go rename to server/responsewriter/writer_test.go index 9893d2ce..8f8e662f 100644 --- a/headerwriter/writer_test.go +++ b/server/responsewriter/writer_test.go @@ -1,4 +1,4 @@ -package headerwriter +package responsewriter import ( "fmt" @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/suite" ) -type HeaderWriterSuite struct { +type ResponseWriterSuite struct { suite.Suite } @@ -22,16 +22,18 @@ type writerTestCase struct { req http.Header res http.Header config Config - fn func(*Request) + fn func(*Writer) } -func (s *HeaderWriterSuite) TestHeaderCases() { +func (s *ResponseWriterSuite) TestHeaderCases() { expires := time.Date(2030, 8, 1, 0, 0, 0, 0, time.UTC) expiresSeconds := strconv.Itoa(int(time.Until(expires).Seconds())) shortExpires := time.Now().Add(10 * time.Second) shortExpiresSeconds := strconv.Itoa(int(time.Until(shortExpires).Seconds())) + writeResponseTimeout := 10 * time.Second + tt := []writerTestCase{ { name: "MinimalHeaders", @@ -44,8 +46,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { SetCanonicalHeader: false, DefaultTTL: 0, CacheControlPassthrough: false, - EnableClientHints: false, - SetVaryAccept: false, + WriteResponseTimeout: writeResponseTimeout, }, }, { @@ -60,6 +61,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ CacheControlPassthrough: true, DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, }, { @@ -74,6 +76,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ CacheControlPassthrough: true, DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, }, { @@ -88,6 +91,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ CacheControlPassthrough: true, DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, }, { @@ -99,10 +103,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - SetCanonicalHeader: true, - DefaultTTL: 3600, + SetCanonicalHeader: true, + DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetCanonical("https://example.com/image.jpg") }, }, @@ -114,8 +119,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - SetCanonicalHeader: true, - DefaultTTL: 3600, + SetCanonicalHeader: true, + DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, }, { @@ -126,10 +132,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - SetCanonicalHeader: false, - DefaultTTL: 3600, + SetCanonicalHeader: false, + DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetCanonical("https://example.com/image.jpg") }, }, @@ -141,10 +148,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - DefaultTTL: 3600, - FallbackImageTTL: 1, + DefaultTTL: 3600, + FallbackImageTTL: 1, + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetIsFallbackImage() }, }, @@ -156,9 +164,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - DefaultTTL: math.MaxInt32, + DefaultTTL: math.MaxInt32, + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetExpires(&expires) }, }, @@ -170,10 +179,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - DefaultTTL: math.MaxInt32, - FallbackImageTTL: 600, + DefaultTTL: math.MaxInt32, + FallbackImageTTL: 600, + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetIsFallbackImage() w.SetExpires(&shortExpires) }, @@ -187,10 +197,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - EnableClientHints: true, - SetVaryAccept: true, + VaryValue: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width", + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetVary() }, }, @@ -204,8 +214,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.CacheControl: []string{"no-cache"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, - config: Config{}, - fn: func(w *Request) { + config: Config{ + WriteResponseTimeout: writeResponseTimeout, + }, + fn: func(w *Writer) { w.Passthrough("X-Test") }, }, @@ -217,8 +229,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.CacheControl: []string{"no-cache"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, - config: Config{}, - fn: func(w *Request) { + config: Config{ + WriteResponseTimeout: writeResponseTimeout, + }, + fn: func(w *Writer) { h := http.Header{} h.Set("X-From", "baz") w.CopyFrom(h, []string{"X-From"}) @@ -232,8 +246,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.CacheControl: []string{"no-cache"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, - config: Config{}, - fn: func(w *Request) { + config: Config{ + WriteResponseTimeout: writeResponseTimeout, + }, + fn: func(w *Writer) { w.SetContentLength(123) }, }, @@ -245,8 +261,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.CacheControl: []string{"no-cache"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, - config: Config{}, - fn: func(w *Request) { + config: Config{ + WriteResponseTimeout: writeResponseTimeout, + }, + fn: func(w *Writer) { w.SetContentType("image/png") }, }, @@ -258,58 +276,30 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{ - DefaultTTL: 3600, + DefaultTTL: 3600, + WriteResponseTimeout: writeResponseTimeout, }, - fn: func(w *Request) { + fn: func(w *Writer) { w.SetExpires(nil) }, }, - { - name: "WriteVaryAcceptOnly", - req: http.Header{}, - res: http.Header{ - httpheaders.Vary: []string{"Accept"}, - httpheaders.CacheControl: []string{"no-cache"}, - httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, - }, - config: Config{ - SetVaryAccept: true, - }, - fn: func(w *Request) { - w.SetVary() - }, - }, - { - name: "WriteVaryClientHintsOnly", - req: http.Header{}, - res: http.Header{ - httpheaders.Vary: []string{"Sec-CH-DPR, DPR, Sec-CH-Width, Width"}, - httpheaders.CacheControl: []string{"no-cache"}, - httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, - }, - config: Config{ - EnableClientHints: true, - }, - fn: func(w *Request) { - w.SetVary() - }, - }, } for _, tc := range tt { s.Run(tc.name, func() { - factory, err := New(&tc.config) + factory, err := NewFactory(&tc.config) s.Require().NoError(err) - writer := factory.NewRequest() + r := httptest.NewRecorder() + + writer := factory.NewWriter(r) writer.SetOriginHeaders(tc.req) if tc.fn != nil { tc.fn(writer) } - r := httptest.NewRecorder() - writer.Write(r) + writer.WriteHeader(http.StatusOK) s.Require().Equal(tc.res, r.Header()) }) @@ -317,5 +307,5 @@ func (s *HeaderWriterSuite) TestHeaderCases() { } func TestHeaderWriter(t *testing.T) { - suite.Run(t, new(HeaderWriterSuite)) + suite.Run(t, new(ResponseWriterSuite)) } diff --git a/server/router.go b/server/router.go index 4349bbeb..8f0b99dc 100644 --- a/server/router.go +++ b/server/router.go @@ -11,6 +11,7 @@ import ( nanoid "github.com/matoous/go-nanoid/v2" "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/server/responsewriter" ) const ( @@ -23,8 +24,10 @@ var ( requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`) ) +type ResponseWriter = *responsewriter.Writer + // RouteHandler is a function that handles HTTP requests. -type RouteHandler func(string, http.ResponseWriter, *http.Request) error +type RouteHandler func(string, ResponseWriter, *http.Request) error // Middleware is a function that wraps a RouteHandler with additional functionality. type Middleware func(next RouteHandler) RouteHandler @@ -40,6 +43,9 @@ type route struct { // Router is responsible for routing HTTP requests type Router struct { + // Response writers factory + rwFactory *responsewriter.Factory + // config represents the server configuration config *Config @@ -53,7 +59,15 @@ func NewRouter(config *Config) (*Router, error) { return nil, err } - return &Router{config: config}, nil + rwf, err := responsewriter.NewFactory(&config.ResponseWriter) + if err != nil { + return nil, err + } + + return &Router{ + rwFactory: rwf, + config: config, + }, nil } // add adds an abitary route to the router @@ -114,8 +128,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { req, timeoutCancel := startRequestTimer(req) defer timeoutCancel() - // Create the response writer which times out on write - rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout) + // Create the [ResponseWriter] + rww := r.rwFactory.NewWriter(rw) // Get/create request ID reqID := r.getRequestID(req) @@ -123,8 +137,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Replace request IP from headers r.replaceRemoteAddr(req) - rw.Header().Set(httpheaders.Server, defaultServerName) - rw.Header().Set(httpheaders.XRequestID, reqID) + rww.Header().Set(httpheaders.Server, defaultServerName) + rww.Header().Set(httpheaders.XRequestID, reqID) for _, rr := range r.routes { if !rr.isMatch(req) { @@ -138,18 +152,18 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { LogRequest(reqID, req) } - rr.handler(reqID, rw, req) + rr.handler(reqID, rww, req) return } // Means that we have not found matching route LogRequest(reqID, req) LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path)) - r.NotFoundHandler(reqID, rw, req) + r.NotFoundHandler(reqID, rww, req) } // NotFoundHandler is default 404 handler -func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error { +func (r *Router) NotFoundHandler(reqID string, rw ResponseWriter, req *http.Request) error { rw.Header().Set(httpheaders.ContentType, "text/plain") rw.WriteHeader(http.StatusNotFound) rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy @@ -158,7 +172,7 @@ func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http } // OkHandler is a default 200 OK handler -func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error { +func (r *Router) OkHandler(reqID string, rw ResponseWriter, req *http.Request) error { rw.Header().Set(httpheaders.ContentType, "text/plain") rw.WriteHeader(http.StatusOK) rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy diff --git a/server/router_test.go b/server/router_test.go index ea8e5551..6178d125 100644 --- a/server/router_test.go +++ b/server/router_test.go @@ -30,7 +30,7 @@ func (s *RouterTestSuite) TestHTTPMethods() { var capturedMethod string var capturedPath string - getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + getHandler := func(reqID string, rw ResponseWriter, req *http.Request) error { capturedMethod = req.Method capturedPath = req.URL.Path rw.WriteHeader(200) @@ -38,7 +38,7 @@ func (s *RouterTestSuite) TestHTTPMethods() { return nil } - optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + optionsHandler := func(reqID string, rw ResponseWriter, req *http.Request) error { capturedMethod = req.Method capturedPath = req.URL.Path rw.WriteHeader(200) @@ -46,7 +46,7 @@ func (s *RouterTestSuite) TestHTTPMethods() { return nil } - headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + headHandler := func(reqID string, rw ResponseWriter, req *http.Request) error { capturedMethod = req.Method capturedPath = req.URL.Path rw.WriteHeader(200) @@ -114,20 +114,20 @@ func (s *RouterTestSuite) TestMiddlewareOrder() { var order []string middleware1 := func(next RouteHandler) RouteHandler { - return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + return func(reqID string, rw ResponseWriter, req *http.Request) error { order = append(order, "middleware1") return next(reqID, rw, req) } } middleware2 := func(next RouteHandler) RouteHandler { - return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + return func(reqID string, rw ResponseWriter, req *http.Request) error { order = append(order, "middleware2") return next(reqID, rw, req) } } - handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + handler := func(reqID string, rw ResponseWriter, req *http.Request) error { order = append(order, "handler") rw.WriteHeader(200) return nil @@ -146,7 +146,7 @@ func (s *RouterTestSuite) TestMiddlewareOrder() { // TestServeHTTP tests ServeHTTP method func (s *RouterTestSuite) TestServeHTTP() { - handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + handler := func(reqID string, rw ResponseWriter, req *http.Request) error { rw.Header().Set("Custom-Header", "test-value") rw.WriteHeader(200) rw.Write([]byte("success")) @@ -169,7 +169,7 @@ func (s *RouterTestSuite) TestServeHTTP() { // TestRequestID checks request ID generation and validation func (s *RouterTestSuite) TestRequestID() { - handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + handler := func(reqID string, rw ResponseWriter, req *http.Request) error { rw.WriteHeader(200) return nil } @@ -209,7 +209,7 @@ func (s *RouterTestSuite) TestRequestID() { // TestLambdaRequestIDExtraction checks AWS lambda request id extraction func (s *RouterTestSuite) TestLambdaRequestIDExtraction() { - handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + handler := func(reqID string, rw ResponseWriter, req *http.Request) error { rw.WriteHeader(200) return nil } @@ -229,7 +229,7 @@ func (s *RouterTestSuite) TestLambdaRequestIDExtraction() { // Test IP address handling func (s *RouterTestSuite) TestReplaceIP() { var capturedRemoteAddr string - handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + handler := func(reqID string, rw ResponseWriter, req *http.Request) error { capturedRemoteAddr = req.RemoteAddr rw.WriteHeader(200) return nil @@ -298,7 +298,7 @@ func (s *RouterTestSuite) TestReplaceIP() { // TestRouteOrder checks exact/non-exact insertion order func (s *RouterTestSuite) TestRouteOrder() { - h := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + h := func(reqID string, rw ResponseWriter, req *http.Request) error { return nil } diff --git a/server/server_test.go b/server/server_test.go index fb0b0fed..3fddcce0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -29,10 +29,14 @@ func (s *ServerTestSuite) SetupTest() { s.blankRouter = r } -func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error { +func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error { return nil } +func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter { + return s.blankRouter.rwFactory.NewWriter(rw) +} + func (s *ServerTestSuite) TestStartServerWithInvalidBind() { ctx, cancel := context.WithCancel(s.T().Context()) @@ -121,7 +125,7 @@ func (s *ServerTestSuite) TestWithCORS() { req := httptest.NewRequest("GET", "/test", nil) rw := httptest.NewRecorder() - wrappedHandler("test-req-id", rw, req) + wrappedHandler("test-req-id", s.wrapRW(rw), req) s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin)) s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods)) @@ -170,7 +174,7 @@ func (s *ServerTestSuite) TestWithSecret() { } rw := httptest.NewRecorder() - err = wrappedHandler("test-req-id", rw, req) + err = wrappedHandler("test-req-id", s.wrapRW(rw), req) if tt.expectError { s.Require().Error(err) @@ -182,7 +186,7 @@ func (s *ServerTestSuite) TestWithSecret() { } func (s *ServerTestSuite) TestIntoSuccess() { - mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error { rw.WriteHeader(http.StatusOK) return nil } @@ -192,14 +196,14 @@ func (s *ServerTestSuite) TestIntoSuccess() { req := httptest.NewRequest("GET", "/test", nil) rw := httptest.NewRecorder() - wrappedHandler("test-req-id", rw, req) + wrappedHandler("test-req-id", s.wrapRW(rw), req) s.Equal(http.StatusOK, rw.Code) } func (s *ServerTestSuite) TestIntoWithError() { testError := errors.New("test error") - mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error { return testError } @@ -208,7 +212,7 @@ func (s *ServerTestSuite) TestIntoWithError() { req := httptest.NewRequest("GET", "/test", nil) rw := httptest.NewRecorder() - wrappedHandler("test-req-id", rw, req) + wrappedHandler("test-req-id", s.wrapRW(rw), req) s.Equal(http.StatusInternalServerError, rw.Code) s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType)) @@ -216,7 +220,7 @@ func (s *ServerTestSuite) TestIntoWithError() { func (s *ServerTestSuite) TestIntoPanicWithError() { testError := errors.New("panic error") - mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error { panic(testError) } @@ -226,7 +230,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() { rw := httptest.NewRecorder() s.NotPanics(func() { - err := wrappedHandler("test-req-id", rw, req) + err := wrappedHandler("test-req-id", s.wrapRW(rw), req) s.Require().Error(err, "panic error") }) @@ -234,7 +238,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() { } func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() { - mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error { panic(http.ErrAbortHandler) } @@ -245,12 +249,12 @@ func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() { // Should re-panic with ErrAbortHandler s.Panics(func() { - wrappedHandler("test-req-id", rw, req) + wrappedHandler("test-req-id", s.wrapRW(rw), req) }) } func (s *ServerTestSuite) TestIntoPanicWithNonError() { - mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error { panic("string panic") } @@ -261,7 +265,7 @@ func (s *ServerTestSuite) TestIntoPanicWithNonError() { // Should re-panic with non-error panics s.NotPanics(func() { - err := wrappedHandler("test-req-id", rw, req) + err := wrappedHandler("test-req-id", s.wrapRW(rw), req) s.Require().Error(err, "string panic") }) } diff --git a/server/timeout_response.go b/server/timeout_response.go deleted file mode 100644 index 1af6c18c..00000000 --- a/server/timeout_response.go +++ /dev/null @@ -1,47 +0,0 @@ -package server - -import ( - "net/http" - "time" -) - -// timeoutResponse manages response writer with timeout. It has -// timeout on all write methods. -type timeoutResponse struct { - http.ResponseWriter - controller *http.ResponseController - timeout time.Duration -} - -// newTimeoutResponse creates a new timeoutResponse -func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter { - return &timeoutResponse{ - ResponseWriter: rw, - controller: http.NewResponseController(rw), - timeout: timeout, - } -} - -// Write implements http.ResponseWriter.Write -func (rw *timeoutResponse) Write(b []byte) (int, error) { - var ( - n int - err error - ) - rw.withWriteDeadline(func() { - n, err = rw.ResponseWriter.Write(b) - }) - return n, err -} - -// withWriteDeadline executes a Write* function with a deadline -func (rw *timeoutResponse) withWriteDeadline(f func()) { - deadline := time.Now().Add(rw.timeout) - - // Set write deadline - rw.controller.SetWriteDeadline(deadline) - - // Reset write deadline after method has finished - defer rw.controller.SetWriteDeadline(time.Time{}) - f() -}