mirror of
https://github.com/imgproxy/imgproxy.git
synced 2025-10-09 19:52:30 +02:00
Rebuild headerwriter to server.ResponseWriter
This commit is contained in:
committed by
Sergei Aleksandrovich
parent
53645688fb
commit
1f6d007948
@@ -6,7 +6,6 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||||
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
|
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
|
||||||
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
|
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
|
||||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/semaphores"
|
"github.com/imgproxy/imgproxy/v3/semaphores"
|
||||||
"github.com/imgproxy/imgproxy/v3/server"
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
)
|
)
|
||||||
@@ -19,7 +18,6 @@ type HandlerConfigs struct {
|
|||||||
|
|
||||||
// Config represents an instance configuration
|
// Config represents an instance configuration
|
||||||
type Config struct {
|
type Config struct {
|
||||||
HeaderWriter headerwriter.Config
|
|
||||||
Semaphores semaphores.Config
|
Semaphores semaphores.Config
|
||||||
FallbackImage auximageprovider.StaticConfig
|
FallbackImage auximageprovider.StaticConfig
|
||||||
WatermarkImage auximageprovider.StaticConfig
|
WatermarkImage auximageprovider.StaticConfig
|
||||||
@@ -31,7 +29,6 @@ type Config struct {
|
|||||||
// NewDefaultConfig creates a new default configuration
|
// NewDefaultConfig creates a new default configuration
|
||||||
func NewDefaultConfig() Config {
|
func NewDefaultConfig() Config {
|
||||||
return Config{
|
return Config{
|
||||||
HeaderWriter: headerwriter.NewDefaultConfig(),
|
|
||||||
Semaphores: semaphores.NewDefaultConfig(),
|
Semaphores: semaphores.NewDefaultConfig(),
|
||||||
FallbackImage: auximageprovider.NewDefaultStaticConfig(),
|
FallbackImage: auximageprovider.NewDefaultStaticConfig(),
|
||||||
WatermarkImage: auximageprovider.NewDefaultStaticConfig(),
|
WatermarkImage: auximageprovider.NewDefaultStaticConfig(),
|
||||||
@@ -62,10 +59,6 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = headerwriter.LoadConfigFromEnv(&c.HeaderWriter); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = semaphores.LoadConfigFromEnv(&c.Semaphores); err != nil {
|
if _, err = semaphores.LoadConfigFromEnv(&c.Semaphores); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -22,7 +22,7 @@ func New() *Handler {
|
|||||||
// Execute handles the health request
|
// Execute handles the health request
|
||||||
func (h *Handler) Execute(
|
func (h *Handler) Execute(
|
||||||
reqID string,
|
reqID string,
|
||||||
rw http.ResponseWriter,
|
rw server.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) error {
|
) error {
|
||||||
var (
|
var (
|
||||||
|
@@ -6,11 +6,18 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthHandler(t *testing.T) {
|
func TestHealthHandler(t *testing.T) {
|
||||||
|
// Create responsewriter.Factory
|
||||||
|
rwConf := responsewriter.NewDefaultConfig()
|
||||||
|
rwf, err := responsewriter.NewFactory(&rwConf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create a ResponseRecorder to record the response
|
// Create a ResponseRecorder to record the response
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -18,7 +25,7 @@ func TestHealthHandler(t *testing.T) {
|
|||||||
h := New()
|
h := New()
|
||||||
|
|
||||||
// Call the handler function directly (no need for actual HTTP request)
|
// Call the handler function directly (no need for actual HTTP request)
|
||||||
h.Execute("test-req-id", rr, nil)
|
h.Execute("test-req-id", rwf.NewWriter(rr), nil)
|
||||||
|
|
||||||
// Check that we get a valid response (either 200 or 500 depending on vips state)
|
// Check that we get a valid response (either 200 or 500 depending on vips state)
|
||||||
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)
|
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed body.html
|
//go:embed body.html
|
||||||
@@ -21,7 +22,7 @@ func New() *Handler {
|
|||||||
// Execute handles the landing request
|
// Execute handles the landing request
|
||||||
func (h *Handler) Execute(
|
func (h *Handler) Execute(
|
||||||
reqID string,
|
reqID string,
|
||||||
rw http.ResponseWriter,
|
rw server.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) error {
|
) error {
|
||||||
rw.Header().Set(httpheaders.ContentType, "text/html")
|
rw.Header().Set(httpheaders.ContentType, "text/html")
|
||||||
|
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||||
"github.com/imgproxy/imgproxy/v3/handlers"
|
"github.com/imgproxy/imgproxy/v3/handlers"
|
||||||
"github.com/imgproxy/imgproxy/v3/handlers/stream"
|
"github.com/imgproxy/imgproxy/v3/handlers/stream"
|
||||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/monitoring"
|
"github.com/imgproxy/imgproxy/v3/monitoring"
|
||||||
@@ -17,11 +16,11 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/security"
|
"github.com/imgproxy/imgproxy/v3/security"
|
||||||
"github.com/imgproxy/imgproxy/v3/semaphores"
|
"github.com/imgproxy/imgproxy/v3/semaphores"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandlerContext provides access to shared handler dependencies
|
// HandlerContext provides access to shared handler dependencies
|
||||||
type HandlerContext interface {
|
type HandlerContext interface {
|
||||||
HeaderWriter() *headerwriter.Writer
|
|
||||||
Semaphores() *semaphores.Semaphores
|
Semaphores() *semaphores.Semaphores
|
||||||
FallbackImage() auximageprovider.Provider
|
FallbackImage() auximageprovider.Provider
|
||||||
WatermarkImage() auximageprovider.Provider
|
WatermarkImage() auximageprovider.Provider
|
||||||
@@ -56,7 +55,7 @@ func New(
|
|||||||
// Execute handles the image processing request
|
// Execute handles the image processing request
|
||||||
func (h *Handler) Execute(
|
func (h *Handler) Execute(
|
||||||
reqID string,
|
reqID string,
|
||||||
rw http.ResponseWriter,
|
rw server.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) error {
|
) error {
|
||||||
// Increment the number of requests in progress
|
// Increment the number of requests in progress
|
||||||
@@ -86,7 +85,6 @@ func (h *Handler) Execute(
|
|||||||
po: po,
|
po: po,
|
||||||
imageURL: imageURL,
|
imageURL: imageURL,
|
||||||
monitoringMeta: mm,
|
monitoringMeta: mm,
|
||||||
hwr: h.HeaderWriter().NewRequest(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return hReq.execute(ctx)
|
return hReq.execute(ctx)
|
||||||
|
@@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||||
"github.com/imgproxy/imgproxy/v3/handlers"
|
"github.com/imgproxy/imgproxy/v3/handlers"
|
||||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||||
"github.com/imgproxy/imgproxy/v3/monitoring"
|
"github.com/imgproxy/imgproxy/v3/monitoring"
|
||||||
@@ -23,12 +22,11 @@ type request struct {
|
|||||||
|
|
||||||
reqID string
|
reqID string
|
||||||
req *http.Request
|
req *http.Request
|
||||||
rw http.ResponseWriter
|
rw server.ResponseWriter
|
||||||
config *Config
|
config *Config
|
||||||
po *options.ProcessingOptions
|
po *options.ProcessingOptions
|
||||||
imageURL string
|
imageURL string
|
||||||
monitoringMeta monitoring.Meta
|
monitoringMeta monitoring.Meta
|
||||||
hwr *headerwriter.Request
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute handles the actual processing logic
|
// execute handles the actual processing logic
|
||||||
@@ -84,13 +82,13 @@ func (r *request) execute(ctx context.Context) error {
|
|||||||
var nmErr fetcher.NotModifiedError
|
var nmErr fetcher.NotModifiedError
|
||||||
|
|
||||||
if errors.As(err, &nmErr) {
|
if errors.As(err, &nmErr) {
|
||||||
r.hwr.SetOriginHeaders(nmErr.Headers())
|
r.rw.SetOriginHeaders(nmErr.Headers())
|
||||||
|
|
||||||
return r.respondWithNotModified()
|
return r.respondWithNotModified()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare to write image response headers
|
// Prepare to write image response headers
|
||||||
r.hwr.SetOriginHeaders(originHeaders)
|
r.rw.SetOriginHeaders(originHeaders)
|
||||||
|
|
||||||
// If error is not related to NotModified, respond with fallback image and replace image data
|
// If error is not related to NotModified, respond with fallback image and replace image data
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -123,7 +121,7 @@ func (r *request) execute(ctx context.Context) error {
|
|||||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryProcessing))
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryProcessing))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write debug headers. It seems unlogical to move they to headerwriter since they're
|
// Write debug headers. It seems unlogical to move they to responsewriter since they're
|
||||||
// not used anywhere else.
|
// not used anywhere else.
|
||||||
err = r.writeDebugHeaders(result, originData)
|
err = r.writeDebugHeaders(result, originData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -123,8 +123,8 @@ func (r *request) handleDownloadError(
|
|||||||
headers.Del(httpheaders.Expires)
|
headers.Del(httpheaders.Expires)
|
||||||
headers.Del(httpheaders.LastModified)
|
headers.Del(httpheaders.LastModified)
|
||||||
|
|
||||||
r.hwr.SetOriginHeaders(headers)
|
r.rw.SetOriginHeaders(headers)
|
||||||
r.hwr.SetIsFallbackImage()
|
r.rw.SetIsFallbackImage()
|
||||||
|
|
||||||
return data, statusCode, nil
|
return data, statusCode, nil
|
||||||
}
|
}
|
||||||
@@ -186,19 +186,17 @@ func (r *request) writeDebugHeaders(result *processing.Result, originData imaged
|
|||||||
|
|
||||||
// respondWithNotModified writes not-modified response
|
// respondWithNotModified writes not-modified response
|
||||||
func (r *request) respondWithNotModified() error {
|
func (r *request) respondWithNotModified() error {
|
||||||
r.hwr.SetExpires(r.po.Expires)
|
r.rw.SetExpires(r.po.Expires)
|
||||||
r.hwr.SetVary()
|
r.rw.SetVary()
|
||||||
|
|
||||||
if r.config.LastModifiedEnabled {
|
if r.config.LastModifiedEnabled {
|
||||||
r.hwr.Passthrough(httpheaders.LastModified)
|
r.rw.Passthrough(httpheaders.LastModified)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.config.ETagEnabled {
|
if r.config.ETagEnabled {
|
||||||
r.hwr.Passthrough(httpheaders.Etag)
|
r.rw.Passthrough(httpheaders.Etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.hwr.Write(r.rw)
|
|
||||||
|
|
||||||
r.rw.WriteHeader(http.StatusNotModified)
|
r.rw.WriteHeader(http.StatusNotModified)
|
||||||
|
|
||||||
server.LogResponse(
|
server.LogResponse(
|
||||||
@@ -221,29 +219,27 @@ func (r *request) respondWithImage(statusCode int, resultData imagedata.ImageDat
|
|||||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryImageDataSize))
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(handlers.CategoryImageDataSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.hwr.SetContentType(resultData.Format().Mime())
|
r.rw.SetContentType(resultData.Format().Mime())
|
||||||
r.hwr.SetContentLength(resultSize)
|
r.rw.SetContentLength(resultSize)
|
||||||
r.hwr.SetContentDisposition(
|
r.rw.SetContentDisposition(
|
||||||
r.imageURL,
|
r.imageURL,
|
||||||
r.po.Filename,
|
r.po.Filename,
|
||||||
resultData.Format().Ext(),
|
resultData.Format().Ext(),
|
||||||
"",
|
"",
|
||||||
r.po.ReturnAttachment,
|
r.po.ReturnAttachment,
|
||||||
)
|
)
|
||||||
r.hwr.SetExpires(r.po.Expires)
|
r.rw.SetExpires(r.po.Expires)
|
||||||
r.hwr.SetVary()
|
r.rw.SetVary()
|
||||||
r.hwr.SetCanonical(r.imageURL)
|
r.rw.SetCanonical(r.imageURL)
|
||||||
|
|
||||||
if r.config.LastModifiedEnabled {
|
if r.config.LastModifiedEnabled {
|
||||||
r.hwr.Passthrough(httpheaders.LastModified)
|
r.rw.Passthrough(httpheaders.LastModified)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.config.ETagEnabled {
|
if r.config.ETagEnabled {
|
||||||
r.hwr.Passthrough(httpheaders.Etag)
|
r.rw.Passthrough(httpheaders.Etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.hwr.Write(r.rw)
|
|
||||||
|
|
||||||
r.rw.WriteHeader(statusCode)
|
r.rw.WriteHeader(statusCode)
|
||||||
|
|
||||||
_, err = io.Copy(r.rw, resultData.Reader())
|
_, err = io.Copy(r.rw, resultData.Reader())
|
||||||
|
@@ -6,16 +6,16 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/cookies"
|
"github.com/imgproxy/imgproxy/v3/cookies"
|
||||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
"github.com/imgproxy/imgproxy/v3/monitoring"
|
"github.com/imgproxy/imgproxy/v3/monitoring"
|
||||||
"github.com/imgproxy/imgproxy/v3/monitoring/stats"
|
"github.com/imgproxy/imgproxy/v3/monitoring/stats"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/server"
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -35,9 +35,8 @@ var (
|
|||||||
|
|
||||||
// Handler handles image passthrough requests, allowing images to be streamed directly
|
// Handler handles image passthrough requests, allowing images to be streamed directly
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
config *Config // Configuration for the streamer
|
config *Config // Configuration for the streamer
|
||||||
fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching
|
fetcher *fetcher.Fetcher // Fetcher instance to handle image fetching
|
||||||
hw *headerwriter.Writer // Configured HeaderWriter instance
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// request holds the parameters and state for a single streaming request
|
// request holds the parameters and state for a single streaming request
|
||||||
@@ -47,12 +46,11 @@ type request struct {
|
|||||||
imageURL string
|
imageURL string
|
||||||
reqID string
|
reqID string
|
||||||
po *options.ProcessingOptions
|
po *options.ProcessingOptions
|
||||||
rw http.ResponseWriter
|
rw server.ResponseWriter
|
||||||
hw *headerwriter.Request
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates new handler object
|
// New creates new handler object
|
||||||
func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Handler, error) {
|
func New(config *Config, fetcher *fetcher.Fetcher) (*Handler, error) {
|
||||||
if err := config.Validate(); err != nil {
|
if err := config.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -60,7 +58,6 @@ func New(config *Config, hw *headerwriter.Writer, fetcher *fetcher.Fetcher) (*Ha
|
|||||||
return &Handler{
|
return &Handler{
|
||||||
fetcher: fetcher,
|
fetcher: fetcher,
|
||||||
config: config,
|
config: config,
|
||||||
hw: hw,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +68,7 @@ func (s *Handler) Execute(
|
|||||||
imageURL string,
|
imageURL string,
|
||||||
reqID string,
|
reqID string,
|
||||||
po *options.ProcessingOptions,
|
po *options.ProcessingOptions,
|
||||||
rw http.ResponseWriter,
|
rw server.ResponseWriter,
|
||||||
) error {
|
) error {
|
||||||
stream := &request{
|
stream := &request{
|
||||||
handler: s,
|
handler: s,
|
||||||
@@ -80,7 +77,6 @@ func (s *Handler) Execute(
|
|||||||
reqID: reqID,
|
reqID: reqID,
|
||||||
po: po,
|
po: po,
|
||||||
rw: rw,
|
rw: rw,
|
||||||
hw: s.hw.NewRequest(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return stream.execute(ctx)
|
return stream.execute(ctx)
|
||||||
@@ -118,17 +114,14 @@ func (s *request) execute(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Output streaming response headers
|
// Output streaming response headers
|
||||||
s.hw.SetOriginHeaders(res.Header)
|
s.rw.SetOriginHeaders(res.Header)
|
||||||
s.hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
|
s.rw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was
|
||||||
s.hw.SetContentLength(int(res.ContentLength))
|
s.rw.SetContentLength(int(res.ContentLength))
|
||||||
s.hw.SetCanonical(s.imageURL)
|
s.rw.SetCanonical(s.imageURL)
|
||||||
s.hw.SetExpires(s.po.Expires)
|
s.rw.SetExpires(s.po.Expires)
|
||||||
|
|
||||||
// Set the Content-Disposition header
|
// Set the Content-Disposition header
|
||||||
s.setContentDisposition(r.URL().Path, res, s.hw)
|
s.setContentDisposition(r.URL().Path, res)
|
||||||
|
|
||||||
// Write headers from writer
|
|
||||||
s.hw.Write(s.rw)
|
|
||||||
|
|
||||||
// Copy the status code from the original response
|
// Copy the status code from the original response
|
||||||
s.rw.WriteHeader(res.StatusCode)
|
s.rw.WriteHeader(res.StatusCode)
|
||||||
@@ -158,7 +151,7 @@ func (s *request) getImageRequestHeaders() http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setContentDisposition writes the headers to the response writer
|
// setContentDisposition writes the headers to the response writer
|
||||||
func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response, hw *headerwriter.Request) {
|
func (s *request) setContentDisposition(imagePath string, serverResponse *http.Response) {
|
||||||
// Try to set correct Content-Disposition file name and extension
|
// Try to set correct Content-Disposition file name and extension
|
||||||
if serverResponse.StatusCode < 200 || serverResponse.StatusCode >= 300 {
|
if serverResponse.StatusCode < 200 || serverResponse.StatusCode >= 300 {
|
||||||
return
|
return
|
||||||
@@ -166,7 +159,7 @@ func (s *request) setContentDisposition(imagePath string, serverResponse *http.R
|
|||||||
|
|
||||||
ct := serverResponse.Header.Get(httpheaders.ContentType)
|
ct := serverResponse.Header.Get(httpheaders.ContentType)
|
||||||
|
|
||||||
hw.SetContentDisposition(
|
s.rw.SetContentDisposition(
|
||||||
imagePath,
|
imagePath,
|
||||||
s.po.Filename,
|
s.po.Filename,
|
||||||
"",
|
"",
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
package stream
|
package stream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -17,9 +16,10 @@ import (
|
|||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
"github.com/imgproxy/imgproxy/v3/fetcher"
|
"github.com/imgproxy/imgproxy/v3/fetcher"
|
||||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -27,14 +27,54 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type HandlerTestSuite struct {
|
type HandlerTestSuite struct {
|
||||||
suite.Suite
|
testutil.LazySuite
|
||||||
handler *Handler
|
|
||||||
|
rwConf testutil.LazyObj[*responsewriter.Config]
|
||||||
|
rwFactory testutil.LazyObj[*responsewriter.Factory]
|
||||||
|
|
||||||
|
config testutil.LazyObj[*Config]
|
||||||
|
handler testutil.LazyObj[*Handler]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HandlerTestSuite) SetupSuite() {
|
func (s *HandlerTestSuite) SetupSuite() {
|
||||||
config.Reset()
|
config.Reset()
|
||||||
config.AllowLoopbackSourceAddresses = true
|
config.AllowLoopbackSourceAddresses = true
|
||||||
|
|
||||||
|
s.rwConf, _ = testutil.NewLazySuiteObj(
|
||||||
|
s,
|
||||||
|
func() (*responsewriter.Config, error) {
|
||||||
|
c := responsewriter.NewDefaultConfig()
|
||||||
|
return &c, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
s.rwFactory, _ = testutil.NewLazySuiteObj(
|
||||||
|
s,
|
||||||
|
func() (*responsewriter.Factory, error) {
|
||||||
|
return responsewriter.NewFactory(s.rwConf())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
s.config, _ = testutil.NewLazySuiteObj(
|
||||||
|
s,
|
||||||
|
func() (*Config, error) {
|
||||||
|
c := NewDefaultConfig()
|
||||||
|
return &c, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
s.handler, _ = testutil.NewLazySuiteObj(
|
||||||
|
s,
|
||||||
|
func() (*Handler, error) {
|
||||||
|
fc := fetcher.NewDefaultConfig()
|
||||||
|
|
||||||
|
fetcher, err := fetcher.New(&fc)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
return New(s.config(), fetcher)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
// Silence logs during tests
|
// Silence logs during tests
|
||||||
logrus.SetOutput(io.Discard)
|
logrus.SetOutput(io.Discard)
|
||||||
}
|
}
|
||||||
@@ -47,21 +87,11 @@ func (s *HandlerTestSuite) TearDownSuite() {
|
|||||||
func (s *HandlerTestSuite) SetupTest() {
|
func (s *HandlerTestSuite) SetupTest() {
|
||||||
config.Reset()
|
config.Reset()
|
||||||
config.AllowLoopbackSourceAddresses = true
|
config.AllowLoopbackSourceAddresses = true
|
||||||
|
}
|
||||||
|
|
||||||
fc := fetcher.NewDefaultConfig()
|
func (s *HandlerTestSuite) SetupSubTest() {
|
||||||
|
// We use t.Run() a lot, so we need to reset lazy objects at the beginning of each subtest
|
||||||
fetcher, err := fetcher.New(&fc)
|
s.ResetLazyObjects()
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
cfg := NewDefaultConfig()
|
|
||||||
|
|
||||||
hwc := headerwriter.NewDefaultConfig()
|
|
||||||
hw, err := headerwriter.New(&hwc)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
h, err := New(&cfg, hw, fetcher)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
s.handler = h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HandlerTestSuite) readTestFile(name string) []byte {
|
func (s *HandlerTestSuite) readTestFile(name string) []byte {
|
||||||
@@ -70,6 +100,24 @@ func (s *HandlerTestSuite) readTestFile(name string) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *HandlerTestSuite) execute(
|
||||||
|
imageURL string,
|
||||||
|
header http.Header,
|
||||||
|
po *options.ProcessingOptions,
|
||||||
|
) *httptest.ResponseRecorder {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
httpheaders.CopyAll(header, req.Header, true)
|
||||||
|
|
||||||
|
ctx := s.T().Context()
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
rww := s.rwFactory().NewWriter(rw)
|
||||||
|
|
||||||
|
err := s.handler().Execute(ctx, req, imageURL, "test-req-id", po, rww)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
return rw
|
||||||
|
}
|
||||||
|
|
||||||
// TestHandlerBasicRequest checks basic streaming request
|
// TestHandlerBasicRequest checks basic streaming request
|
||||||
func (s *HandlerTestSuite) TestHandlerBasicRequest() {
|
func (s *HandlerTestSuite) TestHandlerBasicRequest() {
|
||||||
data := s.readTestFile("test1.png")
|
data := s.readTestFile("test1.png")
|
||||||
@@ -81,12 +129,7 @@ func (s *HandlerTestSuite) TestHandlerBasicRequest() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err := s.handler.Execute(context.Background(), req, ts.URL, "request-1", po, rw)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
@@ -114,12 +157,7 @@ func (s *HandlerTestSuite) TestHandlerResponseHeadersPassthrough() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
@@ -148,16 +186,12 @@ func (s *HandlerTestSuite) TestHandlerRequestHeadersPassthrough() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
h := make(http.Header)
|
||||||
req.Header.Set(httpheaders.IfNoneMatch, etag)
|
h.Set(httpheaders.IfNoneMatch, etag)
|
||||||
req.Header.Set(httpheaders.AcceptEncoding, "gzip")
|
h.Set(httpheaders.AcceptEncoding, "gzip")
|
||||||
req.Header.Set(httpheaders.Range, "bytes=*")
|
h.Set(httpheaders.Range, "bytes=*")
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
@@ -175,8 +209,6 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{
|
po := &options.ProcessingOptions{
|
||||||
Filename: "custom_name",
|
Filename: "custom_name",
|
||||||
ReturnAttachment: true,
|
ReturnAttachment: true,
|
||||||
@@ -184,8 +216,7 @@ func (s *HandlerTestSuite) TestHandlerContentDisposition() {
|
|||||||
|
|
||||||
// Use a URL with a .png extension to help content disposition logic
|
// Use a URL with a .png extension to help content disposition logic
|
||||||
imageURL := ts.URL + "/test.png"
|
imageURL := ts.URL + "/test.png"
|
||||||
err := s.handler.Execute(context.Background(), req, imageURL, "test-req-id", po, rw)
|
rw := s.execute(imageURL, nil, po)
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
@@ -342,25 +373,9 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
fc, err := fetcher.LoadConfigFromEnv(nil)
|
s.rwConf().CacheControlPassthrough = tc.cacheControlPassthrough
|
||||||
s.Require().NoError(err)
|
s.rwConf().DefaultTTL = 4242
|
||||||
|
|
||||||
fetcher, err := fetcher.New(fc)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
cfg := NewDefaultConfig()
|
|
||||||
hwc := headerwriter.NewDefaultConfig()
|
|
||||||
hwc.CacheControlPassthrough = tc.cacheControlPassthrough
|
|
||||||
hwc.DefaultTTL = 4242
|
|
||||||
|
|
||||||
hw, err := headerwriter.New(&hwc)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
handler, err := New(&cfg, hw, fetcher)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
po := &options.ProcessingOptions{}
|
||||||
|
|
||||||
if tc.timestampOffset != nil {
|
if tc.timestampOffset != nil {
|
||||||
@@ -368,8 +383,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
|
|||||||
po.Expires = &expires
|
po.Expires = &expires
|
||||||
}
|
}
|
||||||
|
|
||||||
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
rw := s.execute(ts.URL, nil, po)
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
|
s.Require().Equal(tc.expectedStatusCode, res.StatusCode)
|
||||||
@@ -400,12 +414,7 @@ func (s *HandlerTestSuite) TestHandlerSecurityHeaders() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
@@ -420,12 +429,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err := s.handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(404, res.StatusCode)
|
s.Require().Equal(404, res.StatusCode)
|
||||||
@@ -433,21 +437,7 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
|
|||||||
|
|
||||||
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
|
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
|
||||||
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
|
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
|
||||||
fc, err := fetcher.LoadConfigFromEnv(nil)
|
s.config().CookiePassthrough = true
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
fetcher, err := fetcher.New(fc)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
cfg := NewDefaultConfig()
|
|
||||||
cfg.CookiePassthrough = true
|
|
||||||
|
|
||||||
hwc := headerwriter.NewDefaultConfig()
|
|
||||||
hw, err := headerwriter.New(&hwc)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
handler, err := New(&cfg, hw, fetcher)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
data := s.readTestFile("test1.png")
|
data := s.readTestFile("test1.png")
|
||||||
|
|
||||||
@@ -464,13 +454,10 @@ func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
h := make(http.Header)
|
||||||
req.Header.Set(httpheaders.Cookie, "test_cookie=test_value")
|
h.Set(httpheaders.Cookie, "test_cookie=test_value")
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
rw := s.execute(ts.URL, h, &options.ProcessingOptions{})
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
@@ -488,29 +475,9 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
|
|||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
for _, sc := range []bool{true, false} {
|
for _, sc := range []bool{true, false} {
|
||||||
fc, err := fetcher.LoadConfigFromEnv(nil)
|
s.rwConf().SetCanonicalHeader = sc
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
fetcher, err := fetcher.New(fc)
|
rw := s.execute(ts.URL, nil, &options.ProcessingOptions{})
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
cfg := NewDefaultConfig()
|
|
||||||
hwc := headerwriter.NewDefaultConfig()
|
|
||||||
|
|
||||||
hwc.SetCanonicalHeader = sc
|
|
||||||
|
|
||||||
hw, err := headerwriter.New(&hwc)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
handler, err := New(&cfg, hw, fetcher)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
po := &options.ProcessingOptions{}
|
|
||||||
|
|
||||||
err = handler.Execute(context.Background(), req, ts.URL, "test-req-id", po, rw)
|
|
||||||
s.Require().NoError(err)
|
|
||||||
|
|
||||||
res := rw.Result()
|
res := rw.Result()
|
||||||
s.Require().Equal(200, res.StatusCode)
|
s.Require().Equal(200, res.StatusCode)
|
||||||
|
@@ -1,62 +0,0 @@
|
|||||||
package headerwriter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/ensure"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config is the package-local configuration
|
|
||||||
type Config struct {
|
|
||||||
SetCanonicalHeader bool // Indicates whether to set the canonical header
|
|
||||||
DefaultTTL int // Default Cache-Control max-age= value for cached images
|
|
||||||
FallbackImageTTL int // TTL for images served as fallbacks
|
|
||||||
CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
|
|
||||||
EnableClientHints bool // Enable Vary header
|
|
||||||
SetVaryAccept bool // Whether to include Accept in Vary header
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultConfig returns a new Config instance with default values.
|
|
||||||
func NewDefaultConfig() Config {
|
|
||||||
return Config{
|
|
||||||
SetCanonicalHeader: false,
|
|
||||||
DefaultTTL: 31536000,
|
|
||||||
FallbackImageTTL: 0,
|
|
||||||
CacheControlPassthrough: false,
|
|
||||||
EnableClientHints: false,
|
|
||||||
SetVaryAccept: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadConfigFromEnv overrides configuration variables from environment
|
|
||||||
func LoadConfigFromEnv(c *Config) (*Config, error) {
|
|
||||||
c = ensure.Ensure(c, NewDefaultConfig)
|
|
||||||
|
|
||||||
c.SetCanonicalHeader = config.SetCanonicalHeader
|
|
||||||
c.DefaultTTL = config.TTL
|
|
||||||
c.FallbackImageTTL = config.FallbackImageTTL
|
|
||||||
c.CacheControlPassthrough = config.CacheControlPassthrough
|
|
||||||
c.EnableClientHints = config.EnableClientHints
|
|
||||||
c.SetVaryAccept = config.AutoWebp ||
|
|
||||||
config.EnforceWebp ||
|
|
||||||
config.AutoAvif ||
|
|
||||||
config.EnforceAvif ||
|
|
||||||
config.AutoJxl ||
|
|
||||||
config.EnforceJxl
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate checks config for errors
|
|
||||||
func (c *Config) Validate() error {
|
|
||||||
if c.DefaultTTL < 0 {
|
|
||||||
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.FallbackImageTTL < 0 {
|
|
||||||
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@@ -1,214 +0,0 @@
|
|||||||
// headerwriter is responsible for writing processing/stream response headers
|
|
||||||
package headerwriter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Writer is a struct that creates header writer factories.
|
|
||||||
type Writer struct {
|
|
||||||
config *Config
|
|
||||||
varyValue string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request is a private struct that builds HTTP response headers for a specific request.
|
|
||||||
type Request struct {
|
|
||||||
writer *Writer
|
|
||||||
originHeaders http.Header // Original response headers
|
|
||||||
result http.Header // Headers to be written to the response
|
|
||||||
maxAge int // Current max age for Cache-Control header
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new header writer factory with the provided config.
|
|
||||||
func New(config *Config) (*Writer, error) {
|
|
||||||
if err := config.Validate(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
vary := make([]string, 0)
|
|
||||||
|
|
||||||
if config.SetVaryAccept {
|
|
||||||
vary = append(vary, "Accept")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.EnableClientHints {
|
|
||||||
vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
|
|
||||||
}
|
|
||||||
|
|
||||||
varyValue := strings.Join(vary, ", ")
|
|
||||||
|
|
||||||
return &Writer{
|
|
||||||
config: config,
|
|
||||||
varyValue: varyValue,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL.
|
|
||||||
func (w *Writer) NewRequest() *Request {
|
|
||||||
return &Request{
|
|
||||||
writer: w,
|
|
||||||
result: make(http.Header),
|
|
||||||
maxAge: -1,
|
|
||||||
originHeaders: make(http.Header),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOriginHeaders sets the origin headers for the request.
|
|
||||||
func (r *Request) SetOriginHeaders(h http.Header) {
|
|
||||||
r.originHeaders = h
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetIsFallbackImage sets the Fallback-Image header to
|
|
||||||
// indicate that the fallback image was used.
|
|
||||||
func (r *Request) SetIsFallbackImage() {
|
|
||||||
// We set maxAge to FallbackImageTTL if it's explicitly passed
|
|
||||||
if r.writer.config.FallbackImageTTL < 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// However, we should not overwrite existing value if set (or greater than ours)
|
|
||||||
if r.maxAge < 0 || r.maxAge > r.writer.config.FallbackImageTTL {
|
|
||||||
r.maxAge = r.writer.config.FallbackImageTTL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetExpires sets the TTL from time
|
|
||||||
func (r *Request) SetExpires(expires *time.Time) {
|
|
||||||
if expires == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert current maxAge to time
|
|
||||||
currentMaxAgeTime := time.Now().Add(time.Duration(r.maxAge) * time.Second)
|
|
||||||
|
|
||||||
// If maxAge outlives expires or was not set, we'll use expires as maxAge.
|
|
||||||
if r.maxAge < 0 || expires.Before(currentMaxAgeTime) {
|
|
||||||
r.maxAge = min(r.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVary sets the Vary header
|
|
||||||
func (r *Request) SetVary() {
|
|
||||||
if len(r.writer.varyValue) > 0 {
|
|
||||||
r.result.Set(httpheaders.Vary, r.writer.varyValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
|
|
||||||
func (r *Request) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
|
|
||||||
value := httpheaders.ContentDispositionValue(
|
|
||||||
originURL,
|
|
||||||
filename,
|
|
||||||
ext,
|
|
||||||
contentType,
|
|
||||||
returnAttachment,
|
|
||||||
)
|
|
||||||
|
|
||||||
if value != "" {
|
|
||||||
r.result.Set(httpheaders.ContentDisposition, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Passthrough copies specified headers from the original response headers to the response headers.
|
|
||||||
func (r *Request) Passthrough(only ...string) {
|
|
||||||
httpheaders.Copy(r.originHeaders, r.result, only)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CopyFrom copies specified headers from the headers object. Please note that
|
|
||||||
// all the past operations may overwrite those values.
|
|
||||||
func (r *Request) CopyFrom(headers http.Header, only []string) {
|
|
||||||
httpheaders.Copy(headers, r.result, only)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetContentLength sets the Content-Length header
|
|
||||||
func (r *Request) SetContentLength(contentLength int) {
|
|
||||||
if contentLength < 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetContentType sets the Content-Type header
|
|
||||||
func (r *Request) SetContentType(mime string) {
|
|
||||||
r.result.Set(httpheaders.ContentType, mime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeCanonical sets the Link header with the canonical URL.
|
|
||||||
// It is mandatory for any response if enabled in the configuration.
|
|
||||||
func (r *Request) SetCanonical(url string) {
|
|
||||||
if !r.writer.config.SetCanonicalHeader {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
|
|
||||||
value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
|
|
||||||
r.result.Set(httpheaders.Link, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setCacheControl sets the Cache-Control header with the specified value.
|
|
||||||
func (r *Request) setCacheControl(value int) bool {
|
|
||||||
if value <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
r.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
|
|
||||||
func (r *Request) setCacheControlNoCache() {
|
|
||||||
r.result.Set(httpheaders.CacheControl, "no-cache")
|
|
||||||
}
|
|
||||||
|
|
||||||
// setCacheControlPassthrough sets the Cache-Control header from the request
|
|
||||||
// if passthrough is enabled in the configuration.
|
|
||||||
func (r *Request) setCacheControlPassthrough() bool {
|
|
||||||
if !r.writer.config.CacheControlPassthrough || r.maxAge > 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := r.originHeaders.Get(httpheaders.CacheControl); val != "" {
|
|
||||||
r.result.Set(httpheaders.CacheControl, val)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := r.originHeaders.Get(httpheaders.Expires); val != "" {
|
|
||||||
if t, err := time.Parse(http.TimeFormat, val); err == nil {
|
|
||||||
maxAge := max(0, int(time.Until(t).Seconds()))
|
|
||||||
return r.setCacheControl(maxAge)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// setCSP sets the Content-Security-Policy header to prevent script execution.
|
|
||||||
func (r *Request) setCSP() {
|
|
||||||
r.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes the headers to the response writer. It does not overwrite
|
|
||||||
// target headers, which were set outside the header writer.
|
|
||||||
func (r *Request) Write(rw http.ResponseWriter) {
|
|
||||||
// Then, let's try to set Cache-Control using priority order
|
|
||||||
switch {
|
|
||||||
case r.setCacheControl(r.maxAge): // First, try set explicit
|
|
||||||
case r.setCacheControlPassthrough(): // Try to pick up from request headers
|
|
||||||
case r.setCacheControl(r.writer.config.DefaultTTL): // Fallback to default value
|
|
||||||
default:
|
|
||||||
r.setCacheControlNoCache() // By default we use no-cache
|
|
||||||
}
|
|
||||||
|
|
||||||
r.setCSP()
|
|
||||||
|
|
||||||
// Copy all headers to the response without overwriting existing ones
|
|
||||||
httpheaders.CopyAll(r.result, rw.Header(), false)
|
|
||||||
}
|
|
14
imgproxy.go
14
imgproxy.go
@@ -11,7 +11,6 @@ import (
|
|||||||
landinghandler "github.com/imgproxy/imgproxy/v3/handlers/landing"
|
landinghandler "github.com/imgproxy/imgproxy/v3/handlers/landing"
|
||||||
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
|
processinghandler "github.com/imgproxy/imgproxy/v3/handlers/processing"
|
||||||
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
|
streamhandler "github.com/imgproxy/imgproxy/v3/handlers/stream"
|
||||||
"github.com/imgproxy/imgproxy/v3/headerwriter"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/memory"
|
"github.com/imgproxy/imgproxy/v3/memory"
|
||||||
"github.com/imgproxy/imgproxy/v3/monitoring/prometheus"
|
"github.com/imgproxy/imgproxy/v3/monitoring/prometheus"
|
||||||
@@ -34,7 +33,6 @@ type ImgproxyHandlers struct {
|
|||||||
|
|
||||||
// Imgproxy holds all the components needed for imgproxy to function
|
// Imgproxy holds all the components needed for imgproxy to function
|
||||||
type Imgproxy struct {
|
type Imgproxy struct {
|
||||||
headerWriter *headerwriter.Writer
|
|
||||||
semaphores *semaphores.Semaphores
|
semaphores *semaphores.Semaphores
|
||||||
fallbackImage auximageprovider.Provider
|
fallbackImage auximageprovider.Provider
|
||||||
watermarkImage auximageprovider.Provider
|
watermarkImage auximageprovider.Provider
|
||||||
@@ -46,11 +44,6 @@ type Imgproxy struct {
|
|||||||
|
|
||||||
// New creates a new imgproxy instance
|
// New creates a new imgproxy instance
|
||||||
func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
||||||
headerWriter, err := headerwriter.New(&config.HeaderWriter)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fetcher, err := fetcher.New(&config.Fetcher)
|
fetcher, err := fetcher.New(&config.Fetcher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -74,7 +67,6 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
imgproxy := &Imgproxy{
|
imgproxy := &Imgproxy{
|
||||||
headerWriter: headerWriter,
|
|
||||||
semaphores: semaphores,
|
semaphores: semaphores,
|
||||||
fallbackImage: fallbackImage,
|
fallbackImage: fallbackImage,
|
||||||
watermarkImage: watermarkImage,
|
watermarkImage: watermarkImage,
|
||||||
@@ -86,7 +78,7 @@ func New(ctx context.Context, config *Config) (*Imgproxy, error) {
|
|||||||
imgproxy.handlers.Health = healthhandler.New()
|
imgproxy.handlers.Health = healthhandler.New()
|
||||||
imgproxy.handlers.Landing = landinghandler.New()
|
imgproxy.handlers.Landing = landinghandler.New()
|
||||||
|
|
||||||
imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, headerWriter, fetcher)
|
imgproxy.handlers.Stream, err = streamhandler.New(&config.Handlers.Stream, fetcher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -180,10 +172,6 @@ func (i *Imgproxy) startMemoryTicker(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Imgproxy) HeaderWriter() *headerwriter.Writer {
|
|
||||||
return i.headerWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Imgproxy) Semaphores() *semaphores.Semaphores {
|
func (i *Imgproxy) Semaphores() *semaphores.Semaphores {
|
||||||
return i.semaphores
|
return i.semaphores
|
||||||
}
|
}
|
||||||
|
@@ -238,7 +238,7 @@ func (s *ProcessingHandlerTestSuite) TestErrorSavingToSVG() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
|
func (s *ProcessingHandlerTestSuite) TestCacheControlPassthroughCacheControl() {
|
||||||
s.Config().HeaderWriter.CacheControlPassthrough = true
|
s.Config().Server.ResponseWriter.CacheControlPassthrough = true
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
rw.Header().Set(httpheaders.CacheControl, "max-age=1234, public")
|
rw.Header().Set(httpheaders.CacheControl, "max-age=1234, public")
|
||||||
|
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
"github.com/imgproxy/imgproxy/v3/ensure"
|
"github.com/imgproxy/imgproxy/v3/ensure"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config represents HTTP server config
|
// Config represents HTTP server config
|
||||||
@@ -18,7 +19,6 @@ type Config struct {
|
|||||||
PathPrefix string // Path prefix for the server
|
PathPrefix string // Path prefix for the server
|
||||||
MaxClients int // Maximum number of concurrent clients
|
MaxClients int // Maximum number of concurrent clients
|
||||||
ReadRequestTimeout time.Duration // Timeout for reading requests
|
ReadRequestTimeout time.Duration // Timeout for reading requests
|
||||||
WriteResponseTimeout time.Duration // Timeout for writing responses
|
|
||||||
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
|
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
|
||||||
GracefulTimeout time.Duration // Timeout for graceful shutdown
|
GracefulTimeout time.Duration // Timeout for graceful shutdown
|
||||||
CORSAllowOrigin string // CORS allowed origin
|
CORSAllowOrigin string // CORS allowed origin
|
||||||
@@ -27,6 +27,8 @@ type Config struct {
|
|||||||
SocketReusePort bool // Enable SO_REUSEPORT socket option
|
SocketReusePort bool // Enable SO_REUSEPORT socket option
|
||||||
HealthCheckPath string // Health check path from config
|
HealthCheckPath string // Health check path from config
|
||||||
|
|
||||||
|
ResponseWriter responsewriter.Config // Response writer config
|
||||||
|
|
||||||
// TODO: We are not sure where to put it yet
|
// TODO: We are not sure where to put it yet
|
||||||
FreeMemoryInterval time.Duration // Interval for freeing memory
|
FreeMemoryInterval time.Duration // Interval for freeing memory
|
||||||
LogMemStats bool // Log memory stats
|
LogMemStats bool // Log memory stats
|
||||||
@@ -41,7 +43,6 @@ func NewDefaultConfig() Config {
|
|||||||
MaxClients: 2048,
|
MaxClients: 2048,
|
||||||
ReadRequestTimeout: 10 * time.Second,
|
ReadRequestTimeout: 10 * time.Second,
|
||||||
KeepAliveTimeout: 10 * time.Second,
|
KeepAliveTimeout: 10 * time.Second,
|
||||||
WriteResponseTimeout: 10 * time.Second,
|
|
||||||
GracefulTimeout: 20 * time.Second,
|
GracefulTimeout: 20 * time.Second,
|
||||||
CORSAllowOrigin: "",
|
CORSAllowOrigin: "",
|
||||||
Secret: "",
|
Secret: "",
|
||||||
@@ -50,6 +51,8 @@ func NewDefaultConfig() Config {
|
|||||||
HealthCheckPath: "",
|
HealthCheckPath: "",
|
||||||
FreeMemoryInterval: 10 * time.Second,
|
FreeMemoryInterval: 10 * time.Second,
|
||||||
LogMemStats: false,
|
LogMemStats: false,
|
||||||
|
|
||||||
|
ResponseWriter: responsewriter.NewDefaultConfig(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,6 +75,10 @@ func LoadConfigFromEnv(c *Config) (*Config, error) {
|
|||||||
c.FreeMemoryInterval = time.Duration(config.FreeMemoryInterval) * time.Second
|
c.FreeMemoryInterval = time.Duration(config.FreeMemoryInterval) * time.Second
|
||||||
c.LogMemStats = len(os.Getenv("IMGPROXY_LOG_MEM_STATS")) > 0
|
c.LogMemStats = len(os.Getenv("IMGPROXY_LOG_MEM_STATS")) > 0
|
||||||
|
|
||||||
|
if _, err := responsewriter.LoadConfigFromEnv(&c.ResponseWriter); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,10 +96,6 @@ func (c *Config) Validate() error {
|
|||||||
return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout)
|
return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.WriteResponseTimeout <= 0 {
|
|
||||||
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.KeepAliveTimeout < 0 {
|
if c.KeepAliveTimeout < 0 {
|
||||||
return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout)
|
return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout)
|
||||||
}
|
}
|
||||||
|
@@ -23,10 +23,13 @@ func (r *Router) WithMonitoring(h RouteHandler) RouteHandler {
|
|||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
ctx, cancel, rw := monitoring.StartRequest(req.Context(), rw, req)
|
ctx, cancel, newRw := monitoring.StartRequest(req.Context(), rw.HTTPResponseWriter(), req)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// Replace rw.ResponseWriter with new one returned from monitoring
|
||||||
|
rw.SetHTTPResponseWriter(newRw)
|
||||||
|
|
||||||
return h(reqID, rw, req.WithContext(ctx))
|
return h(reqID, rw, req.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -37,7 +40,7 @@ func (r *Router) WithCORS(h RouteHandler) RouteHandler {
|
|||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
|
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
|
||||||
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
|
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
|
||||||
|
|
||||||
@@ -53,7 +56,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
|
|||||||
|
|
||||||
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
|
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
|
||||||
|
|
||||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
|
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
|
||||||
return h(reqID, rw, req)
|
return h(reqID, rw, req)
|
||||||
} else {
|
} else {
|
||||||
@@ -64,7 +67,7 @@ func (r *Router) WithSecret(h RouteHandler) RouteHandler {
|
|||||||
|
|
||||||
// WithPanic recovers panic and converts it to normal error
|
// WithPanic recovers panic and converts it to normal error
|
||||||
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
||||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
|
return func(reqID string, rw ResponseWriter, r *http.Request) (retErr error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// try to recover from panic
|
// try to recover from panic
|
||||||
rerr := recover()
|
rerr := recover()
|
||||||
@@ -94,7 +97,7 @@ func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
|||||||
// WithReportError handles error reporting.
|
// WithReportError handles error reporting.
|
||||||
// It should be placed after `WithMonitoring`, but before `WithPanic`.
|
// It should be placed after `WithMonitoring`, but before `WithPanic`.
|
||||||
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
|
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
|
||||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
// Open the error context
|
// Open the error context
|
||||||
ctx := errorreport.StartRequest(req)
|
ctx := errorreport.StartRequest(req)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
|
87
server/responsewriter/config.go
Normal file
87
server/responsewriter/config.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package responsewriter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/ensure"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds configuration for response writer
|
||||||
|
type Config struct {
|
||||||
|
SetCanonicalHeader bool // Indicates whether to set the canonical header
|
||||||
|
DefaultTTL int // Default Cache-Control max-age= value for cached images
|
||||||
|
FallbackImageTTL int // TTL for images served as fallbacks
|
||||||
|
CacheControlPassthrough bool // Passthrough the Cache-Control from the original response
|
||||||
|
VaryValue string // Value for Vary header
|
||||||
|
WriteResponseTimeout time.Duration // Timeout for response write operations
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultConfig returns a new Config instance with default values.
|
||||||
|
func NewDefaultConfig() Config {
|
||||||
|
return Config{
|
||||||
|
SetCanonicalHeader: false,
|
||||||
|
DefaultTTL: 31536000,
|
||||||
|
FallbackImageTTL: 0,
|
||||||
|
CacheControlPassthrough: false,
|
||||||
|
VaryValue: "",
|
||||||
|
WriteResponseTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfigFromEnv overrides configuration variables from environment
|
||||||
|
func LoadConfigFromEnv(c *Config) (*Config, error) {
|
||||||
|
c = ensure.Ensure(c, NewDefaultConfig)
|
||||||
|
|
||||||
|
c.SetCanonicalHeader = config.SetCanonicalHeader
|
||||||
|
c.DefaultTTL = config.TTL
|
||||||
|
c.FallbackImageTTL = config.FallbackImageTTL
|
||||||
|
c.CacheControlPassthrough = config.CacheControlPassthrough
|
||||||
|
c.WriteResponseTimeout = time.Duration(config.WriteResponseTimeout) * time.Second
|
||||||
|
|
||||||
|
vary := make([]string, 0)
|
||||||
|
|
||||||
|
if c.envEnableFormatDetection() {
|
||||||
|
vary = append(vary, "Accept")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.envEnableClientHints() {
|
||||||
|
vary = append(vary, "Sec-CH-DPR", "DPR", "Sec-CH-Width", "Width")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.VaryValue = strings.Join(vary, ", ")
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) envEnableFormatDetection() bool {
|
||||||
|
return config.AutoWebp ||
|
||||||
|
config.EnforceWebp ||
|
||||||
|
config.AutoAvif ||
|
||||||
|
config.EnforceAvif ||
|
||||||
|
config.AutoJxl ||
|
||||||
|
config.EnforceJxl
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) envEnableClientHints() bool {
|
||||||
|
return config.EnableClientHints
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks config for errors
|
||||||
|
func (c *Config) Validate() error {
|
||||||
|
if c.DefaultTTL < 0 {
|
||||||
|
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.FallbackImageTTL < 0 {
|
||||||
|
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.WriteResponseTimeout <= 0 {
|
||||||
|
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
114
server/responsewriter/config_test.go
Normal file
114
server/responsewriter/config_test.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package responsewriter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResponseWriterConfigSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ResponseWriterConfigSuite) SetupSuite() {
|
||||||
|
logrus.SetOutput(io.Discard)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ResponseWriterConfigSuite) TearDownSuite() {
|
||||||
|
logrus.SetOutput(os.Stdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ResponseWriterConfigSuite) TestLoadingVaryValueFromEnv() {
|
||||||
|
defaultEnv := map[string]string{
|
||||||
|
"IMGPROXY_AUTO_WEBP": "",
|
||||||
|
"IMGPROXY_ENFORCE_WEBP": "",
|
||||||
|
"IMGPROXY_AUTO_AVIF": "",
|
||||||
|
"IMGPROXY_ENFORCE_AVIF": "",
|
||||||
|
"IMGPROXY_AUTO_JXL": "",
|
||||||
|
"IMGPROXY_ENFORCE_JXL": "",
|
||||||
|
"IMGPROXY_ENABLE_CLIENT_HINTS": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
env map[string]string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "AutoWebP",
|
||||||
|
env: map[string]string{"IMGPROXY_AUTO_WEBP": "true"},
|
||||||
|
expected: "Accept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EnforceWebP",
|
||||||
|
env: map[string]string{"IMGPROXY_ENFORCE_WEBP": "true"},
|
||||||
|
expected: "Accept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AutoAVIF",
|
||||||
|
env: map[string]string{"IMGPROXY_AUTO_AVIF": "true"},
|
||||||
|
expected: "Accept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EnforceAVIF",
|
||||||
|
env: map[string]string{"IMGPROXY_ENFORCE_AVIF": "true"},
|
||||||
|
expected: "Accept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AutoJXL",
|
||||||
|
env: map[string]string{"IMGPROXY_AUTO_JXL": "true"},
|
||||||
|
expected: "Accept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EnforceJXL",
|
||||||
|
env: map[string]string{"IMGPROXY_ENFORCE_JXL": "true"},
|
||||||
|
expected: "Accept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EnableClientHints",
|
||||||
|
env: map[string]string{"IMGPROXY_ENABLE_CLIENT_HINTS": "true"},
|
||||||
|
expected: "Sec-CH-DPR, DPR, Sec-CH-Width, Width",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Combined",
|
||||||
|
env: map[string]string{
|
||||||
|
"IMGPROXY_AUTO_WEBP": "true",
|
||||||
|
"IMGPROXY_ENABLE_CLIENT_HINTS": "true",
|
||||||
|
},
|
||||||
|
expected: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
s.Run(fmt.Sprintf("%v", tc.env), func() {
|
||||||
|
// Set default environment variables
|
||||||
|
for key, value := range defaultEnv {
|
||||||
|
s.T().Setenv(key, value)
|
||||||
|
}
|
||||||
|
// Set environment variables
|
||||||
|
for key, value := range tc.env {
|
||||||
|
s.T().Setenv(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove when we removed global config
|
||||||
|
config.Reset()
|
||||||
|
config.Configure()
|
||||||
|
|
||||||
|
// Load config
|
||||||
|
cfg, err := LoadConfigFromEnv(nil)
|
||||||
|
|
||||||
|
// Assert expected values
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(tc.expected, cfg.VaryValue)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWriterConfig(t *testing.T) {
|
||||||
|
suite.Run(t, new(ResponseWriterConfigSuite))
|
||||||
|
}
|
30
server/responsewriter/factory.go
Normal file
30
server/responsewriter/factory.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package responsewriter
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Factory is a struct that creates response writers.
|
||||||
|
type Factory struct {
|
||||||
|
config *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFactory(config *Config) (*Factory, error) {
|
||||||
|
if err := config.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Factory{config}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriter wraps [http.ResponseWriter] into [Writer].
|
||||||
|
func (f *Factory) NewWriter(rw http.ResponseWriter) *Writer {
|
||||||
|
w := &Writer{
|
||||||
|
config: f.config,
|
||||||
|
result: make(http.Header),
|
||||||
|
originHeaders: make(http.Header),
|
||||||
|
maxAge: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.SetHTTPResponseWriter(rw)
|
||||||
|
|
||||||
|
return w
|
||||||
|
}
|
226
server/responsewriter/writer.go
Normal file
226
server/responsewriter/writer.go
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
package responsewriter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Just aliases for [http.ResponseWriter] and [http.ResponseController].
|
||||||
|
// We need them to make them private in [Writer] so they can't be accessed directly.
|
||||||
|
type httpResponseWriter = http.ResponseWriter
|
||||||
|
type httpResponseController = *http.ResponseController
|
||||||
|
|
||||||
|
// Writer is an implementation of [http.ResponseWriter] with additional
|
||||||
|
// functionality for managing response headers.
|
||||||
|
type Writer struct {
|
||||||
|
httpResponseWriter
|
||||||
|
httpResponseController
|
||||||
|
|
||||||
|
config *Config // Configuration for the writer
|
||||||
|
originHeaders http.Header // Original response headers
|
||||||
|
result http.Header // Headers to be written to the response
|
||||||
|
maxAge int // Current max age for Cache-Control header
|
||||||
|
|
||||||
|
beforeWriteOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPResponseWriter returns the underlying http.ResponseWriter.
|
||||||
|
func (w *Writer) HTTPResponseWriter() http.ResponseWriter {
|
||||||
|
return w.httpResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHTTPResponseWriter replaces the underlying http.ResponseWriter.
|
||||||
|
func (w *Writer) SetHTTPResponseWriter(rw http.ResponseWriter) {
|
||||||
|
w.httpResponseWriter = rw
|
||||||
|
w.httpResponseController = http.NewResponseController(rw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOriginHeaders sets the origin headers for the request.
|
||||||
|
func (w *Writer) SetOriginHeaders(h http.Header) {
|
||||||
|
w.originHeaders = h
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIsFallbackImage sets the Fallback-Image header to
|
||||||
|
// indicate that the fallback image was used.
|
||||||
|
func (w *Writer) SetIsFallbackImage() {
|
||||||
|
// We set maxAge to FallbackImageTTL if it's explicitly passed
|
||||||
|
if w.config.FallbackImageTTL < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// However, we should not overwrite existing value if set (or greater than ours)
|
||||||
|
if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL {
|
||||||
|
w.maxAge = w.config.FallbackImageTTL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpires sets the TTL from time
|
||||||
|
func (w *Writer) SetExpires(expires *time.Time) {
|
||||||
|
if expires == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert current maxAge to time
|
||||||
|
currentMaxAgeTime := time.Now().Add(time.Duration(w.maxAge) * time.Second)
|
||||||
|
|
||||||
|
// If maxAge outlives expires or was not set, we'll use expires as maxAge.
|
||||||
|
if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
|
||||||
|
w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVary sets the Vary header
|
||||||
|
func (w *Writer) SetVary() {
|
||||||
|
if val := w.config.VaryValue; len(val) > 0 {
|
||||||
|
w.result.Set(httpheaders.Vary, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContentDisposition sets the Content-Disposition header, passthrough to ContentDispositionValue
|
||||||
|
func (w *Writer) SetContentDisposition(originURL, filename, ext, contentType string, returnAttachment bool) {
|
||||||
|
value := httpheaders.ContentDispositionValue(
|
||||||
|
originURL,
|
||||||
|
filename,
|
||||||
|
ext,
|
||||||
|
contentType,
|
||||||
|
returnAttachment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if value != "" {
|
||||||
|
w.result.Set(httpheaders.ContentDisposition, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passthrough copies specified headers from the original response headers to the response headers.
|
||||||
|
func (w *Writer) Passthrough(only ...string) {
|
||||||
|
httpheaders.Copy(w.originHeaders, w.result, only)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyFrom copies specified headers from the headers object. Please note that
|
||||||
|
// all the past operations may overwrite those values.
|
||||||
|
func (w *Writer) CopyFrom(headers http.Header, only []string) {
|
||||||
|
httpheaders.Copy(headers, w.result, only)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContentLength sets the Content-Length header
|
||||||
|
func (w *Writer) SetContentLength(contentLength int) {
|
||||||
|
if contentLength < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.result.Set(httpheaders.ContentLength, strconv.Itoa(contentLength))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContentType sets the Content-Type header
|
||||||
|
func (w *Writer) SetContentType(mime string) {
|
||||||
|
w.result.Set(httpheaders.ContentType, mime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCanonical sets the Link header with the canonical URL.
|
||||||
|
// It is mandatory for any response if enabled in the configuration.
|
||||||
|
func (w *Writer) SetCanonical(url string) {
|
||||||
|
if !w.config.SetCanonicalHeader {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") {
|
||||||
|
value := fmt.Sprintf(`<%s>; rel="canonical"`, url)
|
||||||
|
w.result.Set(httpheaders.Link, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCacheControl sets the Cache-Control header with the specified value.
|
||||||
|
func (w *Writer) setCacheControl(value int) bool {
|
||||||
|
if value <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
w.result.Set(httpheaders.CacheControl, fmt.Sprintf("max-age=%d, public", value))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCacheControlNoCache sets the Cache-Control header to no-cache (default).
|
||||||
|
func (w *Writer) setCacheControlNoCache() {
|
||||||
|
w.result.Set(httpheaders.CacheControl, "no-cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCacheControlPassthrough sets the Cache-Control header from the request
|
||||||
|
// if passthrough is enabled in the configuration.
|
||||||
|
func (w *Writer) setCacheControlPassthrough() bool {
|
||||||
|
if !w.config.CacheControlPassthrough || w.maxAge > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := w.originHeaders.Get(httpheaders.CacheControl); val != "" {
|
||||||
|
w.result.Set(httpheaders.CacheControl, val)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := w.originHeaders.Get(httpheaders.Expires); val != "" {
|
||||||
|
if t, err := time.Parse(http.TimeFormat, val); err == nil {
|
||||||
|
maxAge := max(0, int(time.Until(t).Seconds()))
|
||||||
|
return w.setCacheControl(maxAge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCSP sets the Content-Security-Policy header to prevent script execution.
|
||||||
|
func (w *Writer) setCSP() {
|
||||||
|
w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushHeaders writes the headers to the response writer. It does not overwrite
|
||||||
|
// target headers, which were set outside the header writer.
|
||||||
|
func (w *Writer) flushHeaders() {
|
||||||
|
// Then, let's try to set Cache-Control using priority order
|
||||||
|
switch {
|
||||||
|
case w.setCacheControl(w.maxAge): // First, try set explicit
|
||||||
|
case w.setCacheControlPassthrough(): // Try to pick up from request headers
|
||||||
|
case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value
|
||||||
|
default:
|
||||||
|
w.setCacheControlNoCache() // By default we use no-cache
|
||||||
|
}
|
||||||
|
|
||||||
|
w.setCSP()
|
||||||
|
|
||||||
|
// Copy all headers to the response without overwriting existing ones
|
||||||
|
httpheaders.CopyAll(w.result, w.Header(), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// beforeWrite is called before [WriteHeader] and [Write]
|
||||||
|
func (w *Writer) beforeWrite() {
|
||||||
|
w.beforeWriteOnce.Do(func() {
|
||||||
|
// We're going to start writing response.
|
||||||
|
// Set write deadline.
|
||||||
|
w.SetWriteDeadline(time.Now().Add(w.config.WriteResponseTimeout))
|
||||||
|
|
||||||
|
// Flush headers before we write anything
|
||||||
|
w.flushHeaders()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader writes the HTTP response header.
|
||||||
|
//
|
||||||
|
// It ensures that all headers are flushed before writing the status code.
|
||||||
|
func (w *Writer) WriteHeader(statusCode int) {
|
||||||
|
w.beforeWrite()
|
||||||
|
|
||||||
|
w.httpResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes the HTTP response body.
|
||||||
|
//
|
||||||
|
// It ensures that all headers are flushed before writing the body.
|
||||||
|
func (w *Writer) Write(b []byte) (int, error) {
|
||||||
|
w.beforeWrite()
|
||||||
|
|
||||||
|
return w.httpResponseWriter.Write(b)
|
||||||
|
}
|
@@ -1,4 +1,4 @@
|
|||||||
package headerwriter
|
package responsewriter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HeaderWriterSuite struct {
|
type ResponseWriterSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,16 +22,18 @@ type writerTestCase struct {
|
|||||||
req http.Header
|
req http.Header
|
||||||
res http.Header
|
res http.Header
|
||||||
config Config
|
config Config
|
||||||
fn func(*Request)
|
fn func(*Writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HeaderWriterSuite) TestHeaderCases() {
|
func (s *ResponseWriterSuite) TestHeaderCases() {
|
||||||
expires := time.Date(2030, 8, 1, 0, 0, 0, 0, time.UTC)
|
expires := time.Date(2030, 8, 1, 0, 0, 0, 0, time.UTC)
|
||||||
expiresSeconds := strconv.Itoa(int(time.Until(expires).Seconds()))
|
expiresSeconds := strconv.Itoa(int(time.Until(expires).Seconds()))
|
||||||
|
|
||||||
shortExpires := time.Now().Add(10 * time.Second)
|
shortExpires := time.Now().Add(10 * time.Second)
|
||||||
shortExpiresSeconds := strconv.Itoa(int(time.Until(shortExpires).Seconds()))
|
shortExpiresSeconds := strconv.Itoa(int(time.Until(shortExpires).Seconds()))
|
||||||
|
|
||||||
|
writeResponseTimeout := 10 * time.Second
|
||||||
|
|
||||||
tt := []writerTestCase{
|
tt := []writerTestCase{
|
||||||
{
|
{
|
||||||
name: "MinimalHeaders",
|
name: "MinimalHeaders",
|
||||||
@@ -44,8 +46,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
SetCanonicalHeader: false,
|
SetCanonicalHeader: false,
|
||||||
DefaultTTL: 0,
|
DefaultTTL: 0,
|
||||||
CacheControlPassthrough: false,
|
CacheControlPassthrough: false,
|
||||||
EnableClientHints: false,
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
SetVaryAccept: false,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -60,6 +61,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
config: Config{
|
config: Config{
|
||||||
CacheControlPassthrough: true,
|
CacheControlPassthrough: true,
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -74,6 +76,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
config: Config{
|
config: Config{
|
||||||
CacheControlPassthrough: true,
|
CacheControlPassthrough: true,
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -88,6 +91,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
config: Config{
|
config: Config{
|
||||||
CacheControlPassthrough: true,
|
CacheControlPassthrough: true,
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -99,10 +103,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
SetCanonicalHeader: true,
|
SetCanonicalHeader: true,
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetCanonical("https://example.com/image.jpg")
|
w.SetCanonical("https://example.com/image.jpg")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -114,8 +119,9 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
SetCanonicalHeader: true,
|
SetCanonicalHeader: true,
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -126,10 +132,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
SetCanonicalHeader: false,
|
SetCanonicalHeader: false,
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetCanonical("https://example.com/image.jpg")
|
w.SetCanonical("https://example.com/image.jpg")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -141,10 +148,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
FallbackImageTTL: 1,
|
FallbackImageTTL: 1,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetIsFallbackImage()
|
w.SetIsFallbackImage()
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -156,9 +164,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
DefaultTTL: math.MaxInt32,
|
DefaultTTL: math.MaxInt32,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetExpires(&expires)
|
w.SetExpires(&expires)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -170,10 +179,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
DefaultTTL: math.MaxInt32,
|
DefaultTTL: math.MaxInt32,
|
||||||
FallbackImageTTL: 600,
|
FallbackImageTTL: 600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetIsFallbackImage()
|
w.SetIsFallbackImage()
|
||||||
w.SetExpires(&shortExpires)
|
w.SetExpires(&shortExpires)
|
||||||
},
|
},
|
||||||
@@ -187,10 +197,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
EnableClientHints: true,
|
VaryValue: "Accept, Sec-CH-DPR, DPR, Sec-CH-Width, Width",
|
||||||
SetVaryAccept: true,
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetVary()
|
w.SetVary()
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -204,8 +214,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.CacheControl: []string{"no-cache"},
|
httpheaders.CacheControl: []string{"no-cache"},
|
||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{},
|
config: Config{
|
||||||
fn: func(w *Request) {
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
|
},
|
||||||
|
fn: func(w *Writer) {
|
||||||
w.Passthrough("X-Test")
|
w.Passthrough("X-Test")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -217,8 +229,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.CacheControl: []string{"no-cache"},
|
httpheaders.CacheControl: []string{"no-cache"},
|
||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{},
|
config: Config{
|
||||||
fn: func(w *Request) {
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
|
},
|
||||||
|
fn: func(w *Writer) {
|
||||||
h := http.Header{}
|
h := http.Header{}
|
||||||
h.Set("X-From", "baz")
|
h.Set("X-From", "baz")
|
||||||
w.CopyFrom(h, []string{"X-From"})
|
w.CopyFrom(h, []string{"X-From"})
|
||||||
@@ -232,8 +246,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.CacheControl: []string{"no-cache"},
|
httpheaders.CacheControl: []string{"no-cache"},
|
||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{},
|
config: Config{
|
||||||
fn: func(w *Request) {
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
|
},
|
||||||
|
fn: func(w *Writer) {
|
||||||
w.SetContentLength(123)
|
w.SetContentLength(123)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -245,8 +261,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.CacheControl: []string{"no-cache"},
|
httpheaders.CacheControl: []string{"no-cache"},
|
||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{},
|
config: Config{
|
||||||
fn: func(w *Request) {
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
|
},
|
||||||
|
fn: func(w *Writer) {
|
||||||
w.SetContentType("image/png")
|
w.SetContentType("image/png")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -258,58 +276,30 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
||||||
},
|
},
|
||||||
config: Config{
|
config: Config{
|
||||||
DefaultTTL: 3600,
|
DefaultTTL: 3600,
|
||||||
|
WriteResponseTimeout: writeResponseTimeout,
|
||||||
},
|
},
|
||||||
fn: func(w *Request) {
|
fn: func(w *Writer) {
|
||||||
w.SetExpires(nil)
|
w.SetExpires(nil)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "WriteVaryAcceptOnly",
|
|
||||||
req: http.Header{},
|
|
||||||
res: http.Header{
|
|
||||||
httpheaders.Vary: []string{"Accept"},
|
|
||||||
httpheaders.CacheControl: []string{"no-cache"},
|
|
||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
|
||||||
},
|
|
||||||
config: Config{
|
|
||||||
SetVaryAccept: true,
|
|
||||||
},
|
|
||||||
fn: func(w *Request) {
|
|
||||||
w.SetVary()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WriteVaryClientHintsOnly",
|
|
||||||
req: http.Header{},
|
|
||||||
res: http.Header{
|
|
||||||
httpheaders.Vary: []string{"Sec-CH-DPR, DPR, Sec-CH-Width, Width"},
|
|
||||||
httpheaders.CacheControl: []string{"no-cache"},
|
|
||||||
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
|
|
||||||
},
|
|
||||||
config: Config{
|
|
||||||
EnableClientHints: true,
|
|
||||||
},
|
|
||||||
fn: func(w *Request) {
|
|
||||||
w.SetVary()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
s.Run(tc.name, func() {
|
s.Run(tc.name, func() {
|
||||||
factory, err := New(&tc.config)
|
factory, err := NewFactory(&tc.config)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
writer := factory.NewRequest()
|
r := httptest.NewRecorder()
|
||||||
|
|
||||||
|
writer := factory.NewWriter(r)
|
||||||
writer.SetOriginHeaders(tc.req)
|
writer.SetOriginHeaders(tc.req)
|
||||||
|
|
||||||
if tc.fn != nil {
|
if tc.fn != nil {
|
||||||
tc.fn(writer)
|
tc.fn(writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
r := httptest.NewRecorder()
|
writer.WriteHeader(http.StatusOK)
|
||||||
writer.Write(r)
|
|
||||||
|
|
||||||
s.Require().Equal(tc.res, r.Header())
|
s.Require().Equal(tc.res, r.Header())
|
||||||
})
|
})
|
||||||
@@ -317,5 +307,5 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHeaderWriter(t *testing.T) {
|
func TestHeaderWriter(t *testing.T) {
|
||||||
suite.Run(t, new(HeaderWriterSuite))
|
suite.Run(t, new(ResponseWriterSuite))
|
||||||
}
|
}
|
@@ -11,6 +11,7 @@ import (
|
|||||||
nanoid "github.com/matoous/go-nanoid/v2"
|
nanoid "github.com/matoous/go-nanoid/v2"
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -23,8 +24,10 @@ var (
|
|||||||
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ResponseWriter = *responsewriter.Writer
|
||||||
|
|
||||||
// RouteHandler is a function that handles HTTP requests.
|
// RouteHandler is a function that handles HTTP requests.
|
||||||
type RouteHandler func(string, http.ResponseWriter, *http.Request) error
|
type RouteHandler func(string, ResponseWriter, *http.Request) error
|
||||||
|
|
||||||
// Middleware is a function that wraps a RouteHandler with additional functionality.
|
// Middleware is a function that wraps a RouteHandler with additional functionality.
|
||||||
type Middleware func(next RouteHandler) RouteHandler
|
type Middleware func(next RouteHandler) RouteHandler
|
||||||
@@ -40,6 +43,9 @@ type route struct {
|
|||||||
|
|
||||||
// Router is responsible for routing HTTP requests
|
// Router is responsible for routing HTTP requests
|
||||||
type Router struct {
|
type Router struct {
|
||||||
|
// Response writers factory
|
||||||
|
rwFactory *responsewriter.Factory
|
||||||
|
|
||||||
// config represents the server configuration
|
// config represents the server configuration
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
@@ -53,7 +59,15 @@ func NewRouter(config *Config) (*Router, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Router{config: config}, nil
|
rwf, err := responsewriter.NewFactory(&config.ResponseWriter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Router{
|
||||||
|
rwFactory: rwf,
|
||||||
|
config: config,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// add adds an abitary route to the router
|
// add adds an abitary route to the router
|
||||||
@@ -114,8 +128,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
req, timeoutCancel := startRequestTimer(req)
|
req, timeoutCancel := startRequestTimer(req)
|
||||||
defer timeoutCancel()
|
defer timeoutCancel()
|
||||||
|
|
||||||
// Create the response writer which times out on write
|
// Create the [ResponseWriter]
|
||||||
rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
|
rww := r.rwFactory.NewWriter(rw)
|
||||||
|
|
||||||
// Get/create request ID
|
// Get/create request ID
|
||||||
reqID := r.getRequestID(req)
|
reqID := r.getRequestID(req)
|
||||||
@@ -123,8 +137,8 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
// Replace request IP from headers
|
// Replace request IP from headers
|
||||||
r.replaceRemoteAddr(req)
|
r.replaceRemoteAddr(req)
|
||||||
|
|
||||||
rw.Header().Set(httpheaders.Server, defaultServerName)
|
rww.Header().Set(httpheaders.Server, defaultServerName)
|
||||||
rw.Header().Set(httpheaders.XRequestID, reqID)
|
rww.Header().Set(httpheaders.XRequestID, reqID)
|
||||||
|
|
||||||
for _, rr := range r.routes {
|
for _, rr := range r.routes {
|
||||||
if !rr.isMatch(req) {
|
if !rr.isMatch(req) {
|
||||||
@@ -138,18 +152,18 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
LogRequest(reqID, req)
|
LogRequest(reqID, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
rr.handler(reqID, rw, req)
|
rr.handler(reqID, rww, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Means that we have not found matching route
|
// Means that we have not found matching route
|
||||||
LogRequest(reqID, req)
|
LogRequest(reqID, req)
|
||||||
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
|
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
|
||||||
r.NotFoundHandler(reqID, rw, req)
|
r.NotFoundHandler(reqID, rww, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotFoundHandler is default 404 handler
|
// NotFoundHandler is default 404 handler
|
||||||
func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
func (r *Router) NotFoundHandler(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||||
rw.WriteHeader(http.StatusNotFound)
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||||
@@ -158,7 +172,7 @@ func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OkHandler is a default 200 OK handler
|
// OkHandler is a default 200 OK handler
|
||||||
func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
func (r *Router) OkHandler(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||||
|
@@ -30,7 +30,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
|
|||||||
var capturedMethod string
|
var capturedMethod string
|
||||||
var capturedPath string
|
var capturedPath string
|
||||||
|
|
||||||
getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
getHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
capturedMethod = req.Method
|
capturedMethod = req.Method
|
||||||
capturedPath = req.URL.Path
|
capturedPath = req.URL.Path
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
@@ -38,7 +38,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
optionsHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
capturedMethod = req.Method
|
capturedMethod = req.Method
|
||||||
capturedPath = req.URL.Path
|
capturedPath = req.URL.Path
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
@@ -46,7 +46,7 @@ func (s *RouterTestSuite) TestHTTPMethods() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
headHandler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
capturedMethod = req.Method
|
capturedMethod = req.Method
|
||||||
capturedPath = req.URL.Path
|
capturedPath = req.URL.Path
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
@@ -114,20 +114,20 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
|
|||||||
var order []string
|
var order []string
|
||||||
|
|
||||||
middleware1 := func(next RouteHandler) RouteHandler {
|
middleware1 := func(next RouteHandler) RouteHandler {
|
||||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
order = append(order, "middleware1")
|
order = append(order, "middleware1")
|
||||||
return next(reqID, rw, req)
|
return next(reqID, rw, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
middleware2 := func(next RouteHandler) RouteHandler {
|
middleware2 := func(next RouteHandler) RouteHandler {
|
||||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
return func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
order = append(order, "middleware2")
|
order = append(order, "middleware2")
|
||||||
return next(reqID, rw, req)
|
return next(reqID, rw, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
order = append(order, "handler")
|
order = append(order, "handler")
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
return nil
|
return nil
|
||||||
@@ -146,7 +146,7 @@ func (s *RouterTestSuite) TestMiddlewareOrder() {
|
|||||||
|
|
||||||
// TestServeHTTP tests ServeHTTP method
|
// TestServeHTTP tests ServeHTTP method
|
||||||
func (s *RouterTestSuite) TestServeHTTP() {
|
func (s *RouterTestSuite) TestServeHTTP() {
|
||||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
rw.Header().Set("Custom-Header", "test-value")
|
rw.Header().Set("Custom-Header", "test-value")
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
rw.Write([]byte("success"))
|
rw.Write([]byte("success"))
|
||||||
@@ -169,7 +169,7 @@ func (s *RouterTestSuite) TestServeHTTP() {
|
|||||||
|
|
||||||
// TestRequestID checks request ID generation and validation
|
// TestRequestID checks request ID generation and validation
|
||||||
func (s *RouterTestSuite) TestRequestID() {
|
func (s *RouterTestSuite) TestRequestID() {
|
||||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -209,7 +209,7 @@ func (s *RouterTestSuite) TestRequestID() {
|
|||||||
|
|
||||||
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
|
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
|
||||||
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
||||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -229,7 +229,7 @@ func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
|||||||
// Test IP address handling
|
// Test IP address handling
|
||||||
func (s *RouterTestSuite) TestReplaceIP() {
|
func (s *RouterTestSuite) TestReplaceIP() {
|
||||||
var capturedRemoteAddr string
|
var capturedRemoteAddr string
|
||||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
handler := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
capturedRemoteAddr = req.RemoteAddr
|
capturedRemoteAddr = req.RemoteAddr
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
return nil
|
return nil
|
||||||
@@ -298,7 +298,7 @@ func (s *RouterTestSuite) TestReplaceIP() {
|
|||||||
// TestRouteOrder checks exact/non-exact insertion order
|
// TestRouteOrder checks exact/non-exact insertion order
|
||||||
func (s *RouterTestSuite) TestRouteOrder() {
|
func (s *RouterTestSuite) TestRouteOrder() {
|
||||||
|
|
||||||
h := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
h := func(reqID string, rw ResponseWriter, req *http.Request) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -29,10 +29,14 @@ func (s *ServerTestSuite) SetupTest() {
|
|||||||
s.blankRouter = r
|
s.blankRouter = r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
func (s *ServerTestSuite) mockHandler(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) wrapRW(rw http.ResponseWriter) ResponseWriter {
|
||||||
|
return s.blankRouter.rwFactory.NewWriter(rw)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
|
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
|
||||||
ctx, cancel := context.WithCancel(s.T().Context())
|
ctx, cancel := context.WithCancel(s.T().Context())
|
||||||
|
|
||||||
@@ -121,7 +125,7 @@ func (s *ServerTestSuite) TestWithCORS() {
|
|||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
wrappedHandler("test-req-id", rw, req)
|
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
|
|
||||||
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
|
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
|
||||||
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
|
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
|
||||||
@@ -170,7 +174,7 @@ func (s *ServerTestSuite) TestWithSecret() {
|
|||||||
}
|
}
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
err = wrappedHandler("test-req-id", rw, req)
|
err = wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
|
|
||||||
if tt.expectError {
|
if tt.expectError {
|
||||||
s.Require().Error(err)
|
s.Require().Error(err)
|
||||||
@@ -182,7 +186,7 @@ func (s *ServerTestSuite) TestWithSecret() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerTestSuite) TestIntoSuccess() {
|
func (s *ServerTestSuite) TestIntoSuccess() {
|
||||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -192,14 +196,14 @@ func (s *ServerTestSuite) TestIntoSuccess() {
|
|||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
wrappedHandler("test-req-id", rw, req)
|
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
|
|
||||||
s.Equal(http.StatusOK, rw.Code)
|
s.Equal(http.StatusOK, rw.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerTestSuite) TestIntoWithError() {
|
func (s *ServerTestSuite) TestIntoWithError() {
|
||||||
testError := errors.New("test error")
|
testError := errors.New("test error")
|
||||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||||
return testError
|
return testError
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,7 +212,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
|
|||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
wrappedHandler("test-req-id", rw, req)
|
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
|
|
||||||
s.Equal(http.StatusInternalServerError, rw.Code)
|
s.Equal(http.StatusInternalServerError, rw.Code)
|
||||||
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
|
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
|
||||||
@@ -216,7 +220,7 @@ func (s *ServerTestSuite) TestIntoWithError() {
|
|||||||
|
|
||||||
func (s *ServerTestSuite) TestIntoPanicWithError() {
|
func (s *ServerTestSuite) TestIntoPanicWithError() {
|
||||||
testError := errors.New("panic error")
|
testError := errors.New("panic error")
|
||||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||||
panic(testError)
|
panic(testError)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,7 +230,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
|
|||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
s.NotPanics(func() {
|
s.NotPanics(func() {
|
||||||
err := wrappedHandler("test-req-id", rw, req)
|
err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
s.Require().Error(err, "panic error")
|
s.Require().Error(err, "panic error")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -234,7 +238,7 @@ func (s *ServerTestSuite) TestIntoPanicWithError() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
||||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||||
panic(http.ErrAbortHandler)
|
panic(http.ErrAbortHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,12 +249,12 @@ func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
|||||||
|
|
||||||
// Should re-panic with ErrAbortHandler
|
// Should re-panic with ErrAbortHandler
|
||||||
s.Panics(func() {
|
s.Panics(func() {
|
||||||
wrappedHandler("test-req-id", rw, req)
|
wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
||||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
mockHandler := func(reqID string, rw ResponseWriter, r *http.Request) error {
|
||||||
panic("string panic")
|
panic("string panic")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,7 +265,7 @@ func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
|||||||
|
|
||||||
// Should re-panic with non-error panics
|
// Should re-panic with non-error panics
|
||||||
s.NotPanics(func() {
|
s.NotPanics(func() {
|
||||||
err := wrappedHandler("test-req-id", rw, req)
|
err := wrappedHandler("test-req-id", s.wrapRW(rw), req)
|
||||||
s.Require().Error(err, "string panic")
|
s.Require().Error(err, "string panic")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@@ -1,47 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// timeoutResponse manages response writer with timeout. It has
|
|
||||||
// timeout on all write methods.
|
|
||||||
type timeoutResponse struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
controller *http.ResponseController
|
|
||||||
timeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTimeoutResponse creates a new timeoutResponse
|
|
||||||
func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
|
|
||||||
return &timeoutResponse{
|
|
||||||
ResponseWriter: rw,
|
|
||||||
controller: http.NewResponseController(rw),
|
|
||||||
timeout: timeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write implements http.ResponseWriter.Write
|
|
||||||
func (rw *timeoutResponse) Write(b []byte) (int, error) {
|
|
||||||
var (
|
|
||||||
n int
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
rw.withWriteDeadline(func() {
|
|
||||||
n, err = rw.ResponseWriter.Write(b)
|
|
||||||
})
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// withWriteDeadline executes a Write* function with a deadline
|
|
||||||
func (rw *timeoutResponse) withWriteDeadline(f func()) {
|
|
||||||
deadline := time.Now().Add(rw.timeout)
|
|
||||||
|
|
||||||
// Set write deadline
|
|
||||||
rw.controller.SetWriteDeadline(deadline)
|
|
||||||
|
|
||||||
// Reset write deadline after method has finished
|
|
||||||
defer rw.controller.SetWriteDeadline(time.Time{})
|
|
||||||
f()
|
|
||||||
}
|
|
Reference in New Issue
Block a user