IMG-51: router -> server ns, handlers ns, added error to handler ret val (#1494)

* Introduced server, handlers, error ret in handlerfn

* Server struct with tests

* replace checkErr with return
This commit is contained in:
Victor Sokolov
2025-08-20 14:31:11 +02:00
committed by GitHub
parent 0ddefe1b85
commit 15bd00b221
29 changed files with 1483 additions and 561 deletions

View File

@@ -7,11 +7,23 @@ import (
"github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/ierrors"
) )
// Monitoring error categories
const (
categoryTimeout = "timeout"
categoryImageDataSize = "image_data_size"
categoryPathParsing = "path_parsing"
categorySecurity = "security"
categoryQueue = "queue"
categoryDownload = "download"
categoryProcessing = "processing"
categoryIO = "IO"
categoryStreaming = "streaming"
)
type ( type (
ResponseWriteError struct{ error } ResponseWriteError struct{ error }
InvalidURLError string InvalidURLError string
TooManyRequestsError struct{} TooManyRequestsError struct{}
InvalidSecretError struct{}
) )
func newResponseWriteError(cause error) *ierrors.Error { func newResponseWriteError(cause error) *ierrors.Error {
@@ -53,15 +65,3 @@ func newTooManyRequestsError() error {
} }
func (e TooManyRequestsError) Error() string { return "Too many requests" } func (e TooManyRequestsError) Error() string { return "Too many requests" }
func newInvalidSecretError() error {
return ierrors.Wrap(
InvalidSecretError{},
1,
ierrors.WithStatusCode(http.StatusForbidden),
ierrors.WithPublicMessage("Forbidden"),
ierrors.WithShouldReport(false),
)
}
func (e InvalidSecretError) Error() string { return "Invalid secret" }

46
handlers/health.go Normal file
View File

@@ -0,0 +1,46 @@
package handlers
import (
"net/http"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/vips"
)
var imgproxyIsRunningMsg = []byte("imgproxy is running")
// HealthHandler handles the health check requests
func HealthHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
var (
status int
msg []byte
ierr *ierrors.Error
)
if err := vips.Health(); err == nil {
status = http.StatusOK
msg = imgproxyIsRunningMsg
} else {
status = http.StatusInternalServerError
msg = []byte("Error")
ierr = ierrors.Wrap(err, 1)
}
if len(msg) == 0 {
msg = []byte{' '}
}
// Log response only if something went wrong
if ierr != nil {
server.LogResponse(reqID, r, status, ierr)
}
rw.Header().Set(httpheaders.ContentType, "text/plain")
rw.Header().Set(httpheaders.CacheControl, "no-cache")
rw.WriteHeader(status)
rw.Write(msg)
return nil
}

32
handlers/health_test.go Normal file
View File

@@ -0,0 +1,32 @@
package handlers
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/imgproxy/imgproxy/v3/httpheaders"
)
func TestHealthHandler(t *testing.T) {
// Create a ResponseRecorder to record the response
rr := httptest.NewRecorder()
// Call the handler function directly (no need for actual HTTP request)
HealthHandler("test-req-id", rr, nil)
// Check that we get a valid response (either 200 or 500 depending on vips state)
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)
// Check headers are set correctly
assert.Equal(t, "text/plain", rr.Header().Get(httpheaders.ContentType))
assert.Equal(t, "no-cache", rr.Header().Get(httpheaders.CacheControl))
// Verify response format and content
body := rr.Body.String()
assert.NotEmpty(t, body)
assert.Equal(t, "imgproxy is running", body)
}

View File

@@ -1,6 +1,10 @@
package main package handlers
import "net/http" import (
"net/http"
"github.com/imgproxy/imgproxy/v3/httpheaders"
)
var landingTmpl = []byte(` var landingTmpl = []byte(`
<!doctype html> <!doctype html>
@@ -39,8 +43,10 @@ var landingTmpl = []byte(`
</html> </html>
`) `)
func handleLanding(reqID string, rw http.ResponseWriter, r *http.Request) { // LandingHandler handles the landing page requests
rw.Header().Set("Content-Type", "text/html") func LandingHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
rw.WriteHeader(200) rw.Header().Set(httpheaders.ContentType, "text/html")
rw.WriteHeader(http.StatusOK)
rw.Write(landingTmpl) rw.Write(landingTmpl)
return nil
} }

View File

@@ -17,6 +17,7 @@ const (
AltSvc = "Alt-Svc" AltSvc = "Alt-Svc"
Authorization = "Authorization" Authorization = "Authorization"
CacheControl = "Cache-Control" CacheControl = "Cache-Control"
CFConnectingIP = "CF-Connecting-IP"
Connection = "Connection" Connection = "Connection"
ContentDisposition = "Content-Disposition" ContentDisposition = "Content-Disposition"
ContentEncoding = "Content-Encoding" ContentEncoding = "Content-Encoding"
@@ -56,6 +57,7 @@ const (
Vary = "Vary" Vary = "Vary"
Via = "Via" Via = "Via"
WwwAuthenticate = "Www-Authenticate" WwwAuthenticate = "Www-Authenticate"
XAmznRequestContextHeader = "x-amzn-request-context"
XContentTypeOptions = "X-Content-Type-Options" XContentTypeOptions = "X-Content-Type-Options"
XForwardedFor = "X-Forwarded-For" XForwardedFor = "X-Forwarded-For"
XForwardedHost = "X-Forwarded-Host" XForwardedHost = "X-Forwarded-Host"
@@ -63,6 +65,8 @@ const (
XFrameOptions = "X-Frame-Options" XFrameOptions = "X-Frame-Options"
XOriginWidth = "X-Origin-Width" XOriginWidth = "X-Origin-Width"
XOriginHeight = "X-Origin-Height" XOriginHeight = "X-Origin-Height"
XRealIP = "X-Real-IP"
XRequestID = "X-Request-ID"
XResultWidth = "X-Result-Width" XResultWidth = "X-Result-Width"
XResultHeight = "X-Result-Height" XResultHeight = "X-Result-Height"
XOriginContentLength = "X-Origin-Content-Length" XOriginContentLength = "X-Origin-Content-Length"

View File

@@ -7,6 +7,10 @@ import (
"strings" "strings"
) )
const (
defaultCategory = "default"
)
type Option func(*Error) type Option func(*Error)
type Error struct { type Error struct {
@@ -16,6 +20,7 @@ type Error struct {
statusCode int statusCode int
publicMessage string publicMessage string
shouldReport bool shouldReport bool
category string
stack []uintptr stack []uintptr
} }
@@ -64,6 +69,14 @@ func (e *Error) Callers() []uintptr {
return e.stack return e.stack
} }
func (e *Error) Category() string {
if e.category == "" {
return defaultCategory
}
return e.category
}
func (e *Error) FormatStackLines() []string { func (e *Error) FormatStackLines() []string {
lines := make([]string, len(e.stack)) lines := make([]string, len(e.stack))
@@ -141,6 +154,12 @@ func WithShouldReport(report bool) Option {
} }
} }
func WithCategory(category string) Option {
return func(e *Error) {
e.category = category
}
}
func callers(skip int) []uintptr { func callers(skip int) []uintptr {
stack := make([]uintptr, 10) stack := make([]uintptr, 10)
n := runtime.Callers(skip+2, stack) n := runtime.Callers(skip+2, stack)

44
main.go
View File

@@ -16,6 +16,7 @@ import (
"github.com/imgproxy/imgproxy/v3/config/loadenv" "github.com/imgproxy/imgproxy/v3/config/loadenv"
"github.com/imgproxy/imgproxy/v3/errorreport" "github.com/imgproxy/imgproxy/v3/errorreport"
"github.com/imgproxy/imgproxy/v3/gliblog" "github.com/imgproxy/imgproxy/v3/gliblog"
"github.com/imgproxy/imgproxy/v3/handlers"
"github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/logger" "github.com/imgproxy/imgproxy/v3/logger"
"github.com/imgproxy/imgproxy/v3/memory" "github.com/imgproxy/imgproxy/v3/memory"
@@ -23,10 +24,37 @@ import (
"github.com/imgproxy/imgproxy/v3/metrics/prometheus" "github.com/imgproxy/imgproxy/v3/metrics/prometheus"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/processing" "github.com/imgproxy/imgproxy/v3/processing"
"github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/version" "github.com/imgproxy/imgproxy/v3/version"
"github.com/imgproxy/imgproxy/v3/vips" "github.com/imgproxy/imgproxy/v3/vips"
) )
const (
faviconPath = "/favicon.ico"
healthPath = "/health"
)
func buildRouter(r *server.Router) *server.Router {
r.GET("/", true, handlers.LandingHandler)
r.GET("", true, handlers.LandingHandler)
r.GET(
"/", false, handleProcessing,
r.WithSecret, r.WithCORS, r.WithPanic, r.WithReportError, r.WithMetrics,
)
r.HEAD("/", false, r.OkHandler, r.WithCORS)
r.OPTIONS("/", false, r.OkHandler, r.WithCORS)
r.GET(faviconPath, true, r.NotFoundHandler).Silent()
r.GET(healthPath, true, handlers.HealthHandler).Silent()
if config.HealthCheckPath != "" {
r.GET(config.HealthCheckPath, true, handlers.HealthHandler).Silent()
}
return r
}
func initialize() error { func initialize() error {
if err := loadenv.Load(); err != nil { if err := loadenv.Load(); err != nil {
return err return err
@@ -103,25 +131,21 @@ func run(ctx context.Context) error {
} }
}() }()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
if err := prometheus.StartServer(cancel); err != nil { if err := prometheus.StartServer(cancel); err != nil {
return err return err
} }
s, err := startServer(cancel) cfg := server.NewConfigFromEnv()
r := server.NewRouter(cfg)
s, err := server.Start(cancel, buildRouter(r))
if err != nil { if err != nil {
return err return err
} }
defer shutdownServer(s) defer s.Shutdown(ctx)
stop := make(chan os.Signal, 1) <-ctx.Done()
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
select {
case <-ctx.Done():
case <-stop:
}
return nil return nil
} }

View File

@@ -161,7 +161,7 @@ func StartServer(cancel context.CancelFunc) error {
s := http.Server{Handler: promhttp.Handler()} s := http.Server{Handler: promhttp.Handler()}
l, err := reuseport.Listen("tcp", config.PrometheusBind) l, err := reuseport.Listen("tcp", config.PrometheusBind, config.SoReuseport)
if err != nil { if err != nil {
return fmt.Errorf("Can't start Prometheus metrics server: %s", err) return fmt.Errorf("Can't start Prometheus metrics server: %s", err)
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/imagetype"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/router" "github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/vips" "github.com/imgproxy/imgproxy/v3/vips"
) )
@@ -89,7 +89,7 @@ func (p pipeline) Run(ctx context.Context, img *vips.Image, po *options.Processi
return err return err
} }
if err := router.CheckTimeout(ctx); err != nil { if err := server.CheckTimeout(ctx); err != nil {
return err return err
} }
} }

View File

@@ -12,8 +12,8 @@ import (
"github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/imagetype"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/router"
"github.com/imgproxy/imgproxy/v3/security" "github.com/imgproxy/imgproxy/v3/security"
"github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/svg" "github.com/imgproxy/imgproxy/v3/svg"
"github.com/imgproxy/imgproxy/v3/vips" "github.com/imgproxy/imgproxy/v3/vips"
) )
@@ -173,7 +173,7 @@ func transformAnimated(ctx context.Context, img *vips.Image, po *options.Process
return err return err
} }
if err = router.CheckTimeout(ctx); err != nil { if err = server.CheckTimeout(ctx); err != nil {
return err return err
} }
} }
@@ -240,7 +240,7 @@ func saveImageToFitBytes(ctx context.Context, po *options.ProcessingOptions, img
} }
imgdata.Close() imgdata.Close()
if err := router.CheckTimeout(ctx); err != nil { if err := server.CheckTimeout(ctx); err != nil {
return nil, err return nil, err
} }

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -27,8 +26,8 @@ import (
"github.com/imgproxy/imgproxy/v3/metrics/stats" "github.com/imgproxy/imgproxy/v3/metrics/stats"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/processing" "github.com/imgproxy/imgproxy/v3/processing"
"github.com/imgproxy/imgproxy/v3/router"
"github.com/imgproxy/imgproxy/v3/security" "github.com/imgproxy/imgproxy/v3/security"
"github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/vips" "github.com/imgproxy/imgproxy/v3/vips"
) )
@@ -122,17 +121,19 @@ func setCanonical(rw http.ResponseWriter, originURL string) {
} }
} }
func writeOriginContentLengthDebugHeader(ctx context.Context, rw http.ResponseWriter, originData imagedata.ImageData) { func writeOriginContentLengthDebugHeader(rw http.ResponseWriter, originData imagedata.ImageData) error {
if !config.EnableDebugHeaders { if !config.EnableDebugHeaders {
return return nil
} }
size, err := originData.Size() size, err := originData.Size()
if err != nil { if err != nil {
checkErr(ctx, "image_data_size", err) return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
} }
rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size)) rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size))
return nil
} }
func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) { func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
@@ -146,13 +147,13 @@ func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight)) rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight))
} }
func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) { func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) error {
// We read the size of the image data here, so we can set Content-Length header. // We read the size of the image data here, so we can set Content-Length header.
// This indireclty ensures that the image data is fully read from the source, no // This indireclty ensures that the image data is fully read from the source, no
// errors happened. // errors happened.
resultSize, err := resultData.Size() resultSize, err := resultData.Size()
if err != nil { if err != nil {
checkErr(r.Context(), "image_data_size", err) return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
} }
contentDisposition := httpheaders.ContentDispositionValue( contentDisposition := httpheaders.ContentDispositionValue(
@@ -183,18 +184,19 @@ func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, sta
ierr = newResponseWriteError(err) ierr = newResponseWriteError(err)
if config.ReportIOErrors { if config.ReportIOErrors {
sendErr(r.Context(), "IO", ierr) return ierrors.Wrap(ierr, 0, ierrors.WithCategory(categoryIO), ierrors.WithShouldReport(true))
errorreport.Report(ierr, r)
} }
} }
router.LogResponse( server.LogResponse(
reqID, r, statusCode, ierr, reqID, r, statusCode, ierr,
log.Fields{ log.Fields{
"image_url": originURL, "image_url": originURL,
"processing_options": po, "processing_options": po,
}, },
) )
return nil
} }
func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders http.Header) { func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders http.Header) {
@@ -202,7 +204,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite
setVary(rw) setVary(rw)
rw.WriteHeader(304) rw.WriteHeader(304)
router.LogResponse( server.LogResponse(
reqID, r, 304, nil, reqID, r, 304, nil,
log.Fields{ log.Fields{
"image_url": originURL, "image_url": originURL,
@@ -211,37 +213,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite
) )
} }
func sendErr(ctx context.Context, errType string, err error) { func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) error {
send := true
if ierr, ok := err.(*ierrors.Error); ok {
switch ierr.StatusCode() {
case http.StatusServiceUnavailable:
errType = "timeout"
case 499:
// Don't need to send a "request cancelled" error
send = false
}
}
if send {
metrics.SendError(ctx, errType, err)
}
}
func sendErrAndPanic(ctx context.Context, errType string, err error) {
sendErr(ctx, errType, err)
panic(err)
}
func checkErr(ctx context.Context, errType string, err error) {
if err == nil {
return
}
sendErrAndPanic(ctx, errType, err)
}
func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
stats.IncRequestsInProgress() stats.IncRequestsInProgress()
defer stats.DecRequestsInProgress() defer stats.DecRequestsInProgress()
@@ -263,19 +235,22 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
signature = path[:signatureEnd] signature = path[:signatureEnd]
path = path[signatureEnd:] path = path[signatureEnd:]
} else { } else {
sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf( return ierrors.Wrap(
http.StatusNotFound, "Invalid path: %s", path), newInvalidURLErrorf(http.StatusNotFound, "Invalid path: %s", path), 0,
ierrors.WithCategory(categoryPathParsing),
) )
} }
path = fixPath(path) path = fixPath(path)
if err := security.VerifySignature(signature, path); err != nil { if err := security.VerifySignature(signature, path); err != nil {
sendErrAndPanic(ctx, "security", err) return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity))
} }
po, imageURL, err := options.ParsePath(path, r.Header) po, imageURL, err := options.ParsePath(path, r.Header)
checkErr(ctx, "path_parsing", err) if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryPathParsing))
}
var imageOrigin any var imageOrigin any
if u, uerr := url.Parse(imageURL); uerr == nil { if u, uerr := url.Parse(imageURL); uerr == nil {
@@ -295,19 +270,20 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
metrics.SetMetadata(ctx, metricsMeta) metrics.SetMetadata(ctx, metricsMeta)
err = security.VerifySourceURL(imageURL) err = security.VerifySourceURL(imageURL)
checkErr(ctx, "security", err) if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity))
}
if po.Raw { if po.Raw {
streamOriginImage(ctx, reqID, r, rw, po, imageURL) return streamOriginImage(ctx, reqID, r, rw, po, imageURL)
return
} }
// SVG is a special case. Though saving to svg is not supported, SVG->SVG is. // SVG is a special case. Though saving to svg is not supported, SVG->SVG is.
if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG { if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG {
sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf( return ierrors.Wrap(newInvalidURLErrorf(
http.StatusUnprocessableEntity, http.StatusUnprocessableEntity,
"Resulting image format is not supported: %s", po.Format, "Resulting image format is not supported: %s", po.Format,
)) ), 0, ierrors.WithCategory(categoryPathParsing))
} }
imgRequestHeader := make(http.Header) imgRequestHeader := make(http.Header)
@@ -339,7 +315,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
} }
// The heavy part starts here, so we need to restrict worker number // The heavy part starts here, so we need to restrict worker number
func() { err = func() error {
defer metrics.StartQueueSegment(ctx)() defer metrics.StartQueueSegment(ctx)()
err = processingSem.Acquire(ctx, 1) err = processingSem.Acquire(ctx, 1)
@@ -347,12 +323,21 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
// We don't actually need to check timeout here, // We don't actually need to check timeout here,
// but it's an easy way to check if this is an actual timeout // but it's an easy way to check if this is an actual timeout
// or the request was canceled // or the request was canceled
checkErr(ctx, "queue", router.CheckTimeout(ctx)) if terr := server.CheckTimeout(ctx); terr != nil {
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
}
// We should never reach this line as err could be only ctx.Err() // We should never reach this line as err could be only ctx.Err()
// and we've already checked for it. But beter safe than sorry // and we've already checked for it. But beter safe than sorry
sendErrAndPanic(ctx, "queue", err)
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryQueue))
} }
return nil
}() }()
if err != nil {
return err
}
defer processingSem.Release(1) defer processingSem.Release(1)
stats.IncImagesInProgress() stats.IncImagesInProgress()
@@ -375,7 +360,9 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
if config.CookiePassthrough { if config.CookiePassthrough {
downloadOpts.CookieJar, err = cookies.JarFromRequest(r) downloadOpts.CookieJar, err = cookies.JarFromRequest(r)
checkErr(ctx, "download", err) if err != nil {
return nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload))
}
} }
return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts) return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts)
@@ -393,26 +380,28 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
} }
respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers()) respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers())
return return nil
default: default:
// This may be a request timeout error or a request cancelled error. // This may be a request timeout error or a request cancelled error.
// Check it before moving further // Check it before moving further
checkErr(ctx, "timeout", router.CheckTimeout(ctx)) if terr := server.CheckTimeout(ctx); terr != nil {
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
}
ierr := ierrors.Wrap(err, 0) ierr := ierrors.Wrap(err, 0)
if config.ReportDownloadingErrors { if config.ReportDownloadingErrors {
ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true)) ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true))
} }
sendErr(ctx, "download", ierr)
if imagedata.FallbackImage == nil { if imagedata.FallbackImage == nil {
panic(ierr) return ierr
} }
// We didn't panic, so the error is not reported. // Just send error
// Report it now metrics.SendError(ctx, categoryDownload, ierr)
// We didn't return, so we have to report error
if ierr.ShouldReport() { if ierr.ShouldReport() {
errorreport.Report(ierr, r) errorreport.Report(ierr, r)
} }
@@ -433,27 +422,33 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
} }
} }
checkErr(ctx, "timeout", router.CheckTimeout(ctx)) if terr := server.CheckTimeout(ctx); terr != nil {
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
}
if config.ETagEnabled && statusCode == http.StatusOK { if config.ETagEnabled && statusCode == http.StatusOK {
imgDataMatch, terr := etagHandler.SetActualImageData(originData, originHeaders) imgDataMatch, eerr := etagHandler.SetActualImageData(originData, originHeaders)
if terr == nil { if eerr != nil && config.ReportIOErrors {
rw.Header().Set("ETag", etagHandler.GenerateActualETag()) return ierrors.Wrap(eerr, 0, ierrors.WithCategory(categoryIO))
}
if imgDataMatch && etagHandler.ProcessingOptionsMatch() { rw.Header().Set("ETag", etagHandler.GenerateActualETag())
respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders)
return if imgDataMatch && etagHandler.ProcessingOptionsMatch() {
} respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders)
return nil
} }
} }
checkErr(ctx, "timeout", router.CheckTimeout(ctx)) if terr := server.CheckTimeout(ctx); terr != nil {
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
}
if !vips.SupportsLoad(originData.Format()) { if !vips.SupportsLoad(originData.Format()) {
sendErrAndPanic(ctx, "processing", newInvalidURLErrorf( return ierrors.Wrap(newInvalidURLErrorf(
http.StatusUnprocessableEntity, http.StatusUnprocessableEntity,
"Source image format is not supported: %s", originData.Format(), "Source image format is not supported: %s", originData.Format(),
)) ), 0, ierrors.WithCategory(categoryProcessing))
} }
result, err := func() (*processing.Result, error) { result, err := func() (*processing.Result, error) {
@@ -468,18 +463,31 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
defer result.OutData.Close() defer result.OutData.Close()
} }
if err != nil { // First, check if the processing error wasn't caused by an image data error
// First, check if the processing error wasn't caused by an image data error if originData.Error() != nil {
checkErr(ctx, "download", originData.Error()) return ierrors.Wrap(originData.Error(), 0, ierrors.WithCategory(categoryDownload))
// If it wasn't, than it was a processing error
sendErrAndPanic(ctx, "processing", err)
} }
checkErr(ctx, "timeout", router.CheckTimeout(ctx)) // If it wasn't, than it was a processing error
if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryProcessing))
}
if terr := server.CheckTimeout(ctx); terr != nil {
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
}
writeDebugHeaders(rw, result) writeDebugHeaders(rw, result)
writeOriginContentLengthDebugHeader(ctx, rw, originData)
respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders) err = writeOriginContentLengthDebugHeader(rw, originData)
if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
}
err = respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders)
if err != nil {
return err
}
return nil
} }

View File

@@ -22,7 +22,7 @@ import (
"github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/imagetype"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/router" "github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/svg" "github.com/imgproxy/imgproxy/v3/svg"
"github.com/imgproxy/imgproxy/v3/testutil" "github.com/imgproxy/imgproxy/v3/testutil"
"github.com/imgproxy/imgproxy/v3/vips" "github.com/imgproxy/imgproxy/v3/vips"
@@ -31,7 +31,7 @@ import (
type ProcessingHandlerTestSuite struct { type ProcessingHandlerTestSuite struct {
suite.Suite suite.Suite
router *router.Router router *server.Router
} }
func (s *ProcessingHandlerTestSuite) SetupSuite() { func (s *ProcessingHandlerTestSuite) SetupSuite() {
@@ -48,7 +48,7 @@ func (s *ProcessingHandlerTestSuite) SetupSuite() {
logrus.SetOutput(io.Discard) logrus.SetOutput(io.Discard)
s.router = buildRouter() s.router = buildRouter(server.NewRouter(server.NewConfigFromEnv()))
} }
func (s *ProcessingHandlerTestSuite) TeardownSuite() { func (s *ProcessingHandlerTestSuite) TeardownSuite() {

View File

@@ -7,12 +7,10 @@ import (
"net" "net"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/imgproxy/imgproxy/v3/config"
) )
func Listen(network, address string) (net.Listener, error) { func Listen(network, address string, reuse bool) (net.Listener, error) {
if config.SoReuseport { if reuse {
log.Warning("SO_REUSEPORT support is not implemented for your OS or Go version") log.Warning("SO_REUSEPORT support is not implemented for your OS or Go version")
} }

View File

@@ -10,12 +10,10 @@ import (
"syscall" "syscall"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/imgproxy/imgproxy/v3/config"
) )
func Listen(network, address string) (net.Listener, error) { func Listen(network, address string, reuse bool) (net.Listener, error) {
if !config.SoReuseport { if !reuse {
return net.Listen(network, address) return net.Listen(network, address)
} }

View File

@@ -1,174 +0,0 @@
package router
import (
"encoding/json"
"net"
"net/http"
"regexp"
"strings"
nanoid "github.com/matoous/go-nanoid/v2"
"github.com/imgproxy/imgproxy/v3/config"
)
const (
xRequestIDHeader = "X-Request-ID"
xAmznRequestContextHeader = "x-amzn-request-context"
)
var (
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
)
type RouteHandler func(string, http.ResponseWriter, *http.Request)
type route struct {
Method string
Prefix string
Handler RouteHandler
Exact bool
}
type Router struct {
prefix string
healthRoutes []string
faviconRoute string
Routes []*route
HealthHandler RouteHandler
}
func (r *route) isMatch(req *http.Request) bool {
if r.Method != req.Method {
return false
}
if r.Exact {
return req.URL.Path == r.Prefix
}
return strings.HasPrefix(req.URL.Path, r.Prefix)
}
func New(prefix string) *Router {
healthRoutes := []string{prefix + "/health"}
if len(config.HealthCheckPath) > 0 {
healthRoutes = append(healthRoutes, prefix+config.HealthCheckPath)
}
return &Router{
prefix: prefix,
healthRoutes: healthRoutes,
faviconRoute: prefix + "/favicon.ico",
Routes: make([]*route, 0),
}
}
func (r *Router) Add(method, prefix string, handler RouteHandler, exact bool) {
// Don't add routes with empty prefix
if len(r.prefix+prefix) == 0 {
return
}
r.Routes = append(
r.Routes,
&route{Method: method, Prefix: r.prefix + prefix, Handler: handler, Exact: exact},
)
}
func (r *Router) GET(prefix string, handler RouteHandler, exact bool) {
r.Add(http.MethodGet, prefix, handler, exact)
}
func (r *Router) OPTIONS(prefix string, handler RouteHandler, exact bool) {
r.Add(http.MethodOptions, prefix, handler, exact)
}
func (r *Router) HEAD(prefix string, handler RouteHandler, exact bool) {
r.Add(http.MethodHead, prefix, handler, exact)
}
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req, timeoutCancel := startRequestTimer(req)
defer timeoutCancel()
rw = newTimeoutResponse(rw)
reqID := req.Header.Get(xRequestIDHeader)
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
if lambdaContextVal := req.Header.Get(xAmznRequestContextHeader); len(lambdaContextVal) > 0 {
var lambdaContext struct {
RequestID string `json:"requestId"`
}
err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext)
if err == nil && len(lambdaContext.RequestID) > 0 {
reqID = lambdaContext.RequestID
}
}
}
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
reqID, _ = nanoid.New()
}
rw.Header().Set("Server", "imgproxy")
rw.Header().Set(xRequestIDHeader, reqID)
if req.Method == http.MethodGet {
if r.HealthHandler != nil {
for _, healthRoute := range r.healthRoutes {
if req.URL.Path == healthRoute {
r.HealthHandler(reqID, rw, req)
return
}
}
}
if req.URL.Path == r.faviconRoute {
// TODO: Add a real favicon maybe?
rw.Header().Set("Content-Type", "text/plain")
rw.WriteHeader(404)
// Write a single byte to make AWS Lambda happy
rw.Write([]byte{' '})
return
}
}
if ip := req.Header.Get("CF-Connecting-IP"); len(ip) != 0 {
replaceRemoteAddr(req, ip)
} else if ip := req.Header.Get("X-Forwarded-For"); len(ip) != 0 {
if index := strings.Index(ip, ","); index > 0 {
ip = ip[:index]
}
replaceRemoteAddr(req, ip)
} else if ip := req.Header.Get("X-Real-IP"); len(ip) != 0 {
replaceRemoteAddr(req, ip)
}
LogRequest(reqID, req)
for _, rr := range r.Routes {
if rr.isMatch(req) {
rr.Handler(reqID, rw, req)
return
}
}
LogResponse(reqID, req, 404, newRouteNotDefinedError(req.URL.Path))
rw.Header().Set("Content-Type", "text/plain")
rw.WriteHeader(404)
rw.Write([]byte{' '})
}
func replaceRemoteAddr(req *http.Request, ip string) {
_, port, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
port = "80"
}
req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port)
}

View File

@@ -1,43 +0,0 @@
package router
import (
"net/http"
"time"
"github.com/imgproxy/imgproxy/v3/config"
)
type timeoutResponse struct {
http.ResponseWriter
controller *http.ResponseController
}
func newTimeoutResponse(rw http.ResponseWriter) http.ResponseWriter {
return &timeoutResponse{
ResponseWriter: rw,
controller: http.NewResponseController(rw),
}
}
func (rw *timeoutResponse) WriteHeader(statusCode int) {
rw.withWriteDeadline(func() {
rw.ResponseWriter.WriteHeader(statusCode)
})
}
func (rw *timeoutResponse) Write(b []byte) (int, error) {
var (
n int
err error
)
rw.withWriteDeadline(func() {
n, err = rw.ResponseWriter.Write(b)
})
return n, err
}
func (rw *timeoutResponse) withWriteDeadline(f func()) {
rw.controller.SetWriteDeadline(time.Now().Add(time.Duration(config.WriteResponseTimeout) * time.Second))
defer rw.controller.SetWriteDeadline(time.Time{})
f()
}

204
server.go
View File

@@ -1,204 +0,0 @@
package main
import (
"context"
"crypto/subtle"
"fmt"
golog "log"
"net/http"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/netutil"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/errorreport"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/metrics"
"github.com/imgproxy/imgproxy/v3/reuseport"
"github.com/imgproxy/imgproxy/v3/router"
"github.com/imgproxy/imgproxy/v3/vips"
)
var imgproxyIsRunningMsg = []byte("imgproxy is running")
func buildRouter() *router.Router {
r := router.New(config.PathPrefix)
r.GET("/", handleLanding, true)
r.GET("", handleLanding, true)
r.GET("/", withMetrics(withPanicHandler(withCORS(withSecret(handleProcessing)))), false)
r.HEAD("/", withCORS(handleHead), false)
r.OPTIONS("/", withCORS(handleHead), false)
r.HealthHandler = handleHealth
return r
}
func startServer(cancel context.CancelFunc) (*http.Server, error) {
l, err := reuseport.Listen(config.Network, config.Bind)
if err != nil {
return nil, fmt.Errorf("can't start server: %s", err)
}
if config.MaxClients > 0 {
l = netutil.LimitListener(l, config.MaxClients)
}
errLogger := golog.New(
log.WithField("source", "http_server").WriterLevel(log.ErrorLevel),
"", 0,
)
s := &http.Server{
Handler: buildRouter(),
ReadTimeout: time.Duration(config.ReadRequestTimeout) * time.Second,
MaxHeaderBytes: 1 << 20,
ErrorLog: errLogger,
}
if config.KeepAliveTimeout > 0 {
s.IdleTimeout = time.Duration(config.KeepAliveTimeout) * time.Second
} else {
s.SetKeepAlivesEnabled(false)
}
go func() {
log.Infof("Starting server at %s", config.Bind)
if err := s.Serve(l); err != nil && err != http.ErrServerClosed {
log.Error(err)
}
cancel()
}()
return s, nil
}
func shutdownServer(s *http.Server) {
log.Info("Shutting down the server...")
ctx, close := context.WithTimeout(context.Background(), 5*time.Second)
defer close()
s.Shutdown(ctx)
}
func withMetrics(h router.RouteHandler) router.RouteHandler {
if !metrics.Enabled() {
return h
}
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
ctx, metricsCancel, rw := metrics.StartRequest(r.Context(), rw, r)
defer metricsCancel()
h(reqID, rw, r.WithContext(ctx))
}
}
func withCORS(h router.RouteHandler) router.RouteHandler {
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
if len(config.AllowOrigin) > 0 {
rw.Header().Set("Access-Control-Allow-Origin", config.AllowOrigin)
rw.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
}
h(reqID, rw, r)
}
}
func withSecret(h router.RouteHandler) router.RouteHandler {
if len(config.Secret) == 0 {
return h
}
authHeader := []byte(fmt.Sprintf("Bearer %s", config.Secret))
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), authHeader) == 1 {
h(reqID, rw, r)
} else {
panic(newInvalidSecretError())
}
}
}
func withPanicHandler(h router.RouteHandler) router.RouteHandler {
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
ctx := errorreport.StartRequest(r)
r = r.WithContext(ctx)
errorreport.SetMetadata(r, "Request ID", reqID)
defer func() {
if rerr := recover(); rerr != nil {
if rerr == http.ErrAbortHandler {
panic(rerr)
}
err, ok := rerr.(error)
if !ok {
panic(rerr)
}
ierr := ierrors.Wrap(err, 0)
if ierr.ShouldReport() {
errorreport.Report(err, r)
}
router.LogResponse(reqID, r, ierr.StatusCode(), ierr)
rw.Header().Set("Content-Type", "text/plain")
rw.WriteHeader(ierr.StatusCode())
if config.DevelopmentErrorsMode {
rw.Write([]byte(ierr.Error()))
} else {
rw.Write([]byte(ierr.PublicMessage()))
}
}
}()
h(reqID, rw, r)
}
}
func handleHealth(reqID string, rw http.ResponseWriter, r *http.Request) {
var (
status int
msg []byte
ierr *ierrors.Error
)
if err := vips.Health(); err == nil {
status = http.StatusOK
msg = imgproxyIsRunningMsg
} else {
status = http.StatusInternalServerError
msg = []byte("Error")
ierr = ierrors.Wrap(err, 1)
}
if len(msg) == 0 {
msg = []byte{' '}
}
// Log response only if something went wrong
if ierr != nil {
router.LogResponse(reqID, r, status, ierr)
}
rw.Header().Set("Content-Type", "text/plain")
rw.Header().Set("Cache-Control", "no-cache")
rw.WriteHeader(status)
rw.Write(msg)
}
func handleHead(reqID string, rw http.ResponseWriter, r *http.Request) {
router.LogResponse(reqID, r, 200, nil)
rw.WriteHeader(200)
}

49
server/config.go Normal file
View File

@@ -0,0 +1,49 @@
package server
import (
"time"
"github.com/imgproxy/imgproxy/v3/config"
)
const (
// gracefulTimeout represents graceful shutdown timeout
gracefulTimeout = time.Duration(5 * time.Second)
)
// Config represents HTTP server config
type Config struct {
Listen string // Address to listen on
Network string // Network type (tcp, unix)
Bind string // Bind address
PathPrefix string // Path prefix for the server
MaxClients int // Maximum number of concurrent clients
ReadRequestTimeout time.Duration // Timeout for reading requests
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
GracefulTimeout time.Duration // Timeout for graceful shutdown
CORSAllowOrigin string // CORS allowed origin
Secret string // Secret for authorization
DevelopmentErrorsMode bool // Enable development mode for detailed error messages
SocketReusePort bool // Enable SO_REUSEPORT socket option
HealthCheckPath string // Health check path from config
WriteResponseTimeout time.Duration
}
// NewConfigFromEnv creates a new Config instance from environment variables
func NewConfigFromEnv() *Config {
return &Config{
Network: config.Network,
Bind: config.Bind,
PathPrefix: config.PathPrefix,
MaxClients: config.MaxClients,
ReadRequestTimeout: time.Duration(config.ReadRequestTimeout) * time.Second,
KeepAliveTimeout: time.Duration(config.KeepAliveTimeout) * time.Second,
GracefulTimeout: gracefulTimeout,
CORSAllowOrigin: config.AllowOrigin,
Secret: config.Secret,
DevelopmentErrorsMode: config.DevelopmentErrorsMode,
SocketReusePort: config.SoReuseport,
HealthCheckPath: config.HealthCheckPath,
WriteResponseTimeout: time.Duration(config.WriteResponseTimeout) * time.Second,
}
}

View File

@@ -1,6 +1,7 @@
package router package server
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@@ -12,6 +13,7 @@ type (
RouteNotDefinedError string RouteNotDefinedError string
RequestCancelledError string RequestCancelledError string
RequestTimeoutError string RequestTimeoutError string
InvalidSecretError struct{}
) )
func newRouteNotDefinedError(path string) *ierrors.Error { func newRouteNotDefinedError(path string) *ierrors.Error {
@@ -33,11 +35,16 @@ func newRequestCancelledError(after time.Duration) *ierrors.Error {
ierrors.WithStatusCode(499), ierrors.WithStatusCode(499),
ierrors.WithPublicMessage("Cancelled"), ierrors.WithPublicMessage("Cancelled"),
ierrors.WithShouldReport(false), ierrors.WithShouldReport(false),
ierrors.WithCategory(categoryTimeout),
) )
} }
func (e RequestCancelledError) Error() string { return string(e) } func (e RequestCancelledError) Error() string { return string(e) }
func (e RequestCancelledError) Unwrap() error {
return context.Canceled
}
func newRequestTimeoutError(after time.Duration) *ierrors.Error { func newRequestTimeoutError(after time.Duration) *ierrors.Error {
return ierrors.Wrap( return ierrors.Wrap(
RequestTimeoutError(fmt.Sprintf("Request was timed out after %v", after)), RequestTimeoutError(fmt.Sprintf("Request was timed out after %v", after)),
@@ -45,7 +52,24 @@ func newRequestTimeoutError(after time.Duration) *ierrors.Error {
ierrors.WithStatusCode(http.StatusServiceUnavailable), ierrors.WithStatusCode(http.StatusServiceUnavailable),
ierrors.WithPublicMessage("Gateway Timeout"), ierrors.WithPublicMessage("Gateway Timeout"),
ierrors.WithShouldReport(false), ierrors.WithShouldReport(false),
ierrors.WithCategory(categoryTimeout),
) )
} }
func (e RequestTimeoutError) Error() string { return string(e) } func (e RequestTimeoutError) Error() string { return string(e) }
func (e RequestTimeoutError) Unwrap() error {
return context.DeadlineExceeded
}
func newInvalidSecretError() error {
return ierrors.Wrap(
InvalidSecretError{},
1,
ierrors.WithStatusCode(http.StatusForbidden),
ierrors.WithPublicMessage("Forbidden"),
ierrors.WithShouldReport(false),
)
}
func (e InvalidSecretError) Error() string { return "Invalid secret" }

View File

@@ -1,4 +1,4 @@
package router package server
import ( import (
"net" "net"
@@ -59,6 +59,6 @@ func LogResponse(reqID string, r *http.Request, status int, err *ierrors.Error,
log.WithFields(fields).Logf( log.WithFields(fields).Logf(
level, level,
"Completed in %s %s", ctxTime(r.Context()), r.RequestURI, "Completed in %s %s", requestStartedAt(r.Context()), r.RequestURI,
) )
} }

145
server/middlewares.go Normal file
View File

@@ -0,0 +1,145 @@
package server
import (
"context"
"crypto/subtle"
"errors"
"fmt"
"net/http"
"github.com/imgproxy/imgproxy/v3/errorreport"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/metrics"
)
const (
categoryTimeout = "timeout"
)
// WithMetrics wraps RouteHandler with metrics handling.
func (r *Router) WithMetrics(h RouteHandler) RouteHandler {
if !metrics.Enabled() {
return h
}
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
ctx, metricsCancel, rw := metrics.StartRequest(req.Context(), rw, req)
defer metricsCancel()
return h(reqID, rw, req.WithContext(ctx))
}
}
// WithCORS wraps RouteHandler with CORS handling
func (r *Router) WithCORS(h RouteHandler) RouteHandler {
if len(r.config.CORSAllowOrigin) == 0 {
return h
}
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
return h(reqID, rw, req)
}
}
// WithSecret wraps RouteHandler with secret handling
func (r *Router) WithSecret(h RouteHandler) RouteHandler {
if len(r.config.Secret) == 0 {
return h
}
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
return h(reqID, rw, req)
} else {
return newInvalidSecretError()
}
}
}
// WithPanic recovers panic and converts it to normal error
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
defer func() {
// try to recover from panic
rerr := recover()
if rerr == nil {
return
}
// abort handler is an exception of net/http, we should simply repanic it.
// it will supress the stack trace
if rerr == http.ErrAbortHandler {
panic(rerr)
}
// let's recover error value from panic if it has panicked with error
err, ok := rerr.(error)
if !ok {
err = fmt.Errorf("panic: %v", err)
}
retErr = err
}()
return h(reqID, rw, r)
}
}
// WithReportError handles error reporting.
// It should be placed after `WithMetrics`, but before `WithPanic`.
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
// Open the error context
ctx := errorreport.StartRequest(req)
req = req.WithContext(ctx)
errorreport.SetMetadata(req, "Request ID", reqID)
// Call the underlying handler passing the context downwards
err := h(reqID, rw, req)
if err == nil {
return nil
}
// Wrap a resulting error into ierrors.Error
ierr := ierrors.Wrap(err, 0)
// Get the error category
errCat := ierr.Category()
// Exception: any context.DeadlineExceeded error is timeout
if errors.Is(ierr, context.DeadlineExceeded) {
errCat = categoryTimeout
}
// We do not need to send any canceled context
if !errors.Is(ierr, context.Canceled) {
metrics.SendError(ctx, errCat, err)
}
// Report error to error collectors
if ierr.ShouldReport() {
errorreport.Report(ierr, req)
}
// Log response and format the error output
LogResponse(reqID, req, ierr.StatusCode(), ierr)
// Error message: either is public message or full development error
rw.Header().Set(httpheaders.ContentType, "text/plain")
rw.WriteHeader(ierr.StatusCode())
if r.config.DevelopmentErrorsMode {
rw.Write([]byte(ierr.Error()))
} else {
rw.Write([]byte(ierr.PublicMessage()))
}
return nil
}
}

209
server/router.go Normal file
View File

@@ -0,0 +1,209 @@
package server
import (
"encoding/json"
"net"
"net/http"
"regexp"
"strings"
nanoid "github.com/matoous/go-nanoid/v2"
"github.com/imgproxy/imgproxy/v3/httpheaders"
)
const (
// defaultServerName is the default name of the server
defaultServerName = "imgproxy"
)
var (
// requestIDRe is a regular expression for validating request IDs
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
)
// RouteHandler is a function that handles HTTP requests.
type RouteHandler func(string, http.ResponseWriter, *http.Request) error
// Middleware is a function that wraps a RouteHandler with additional functionality.
type Middleware func(next RouteHandler) RouteHandler
// route represents a single route in the router.
type route struct {
method string // method is the HTTP method for a route
path string // path represents a route path
exact bool // exact means that path must match exactly, otherwise any prefixed matches
handler RouteHandler // handler is the function that handles the route
silent bool // Silent route (no logs)
}
// Router is responsible for routing HTTP requests
type Router struct {
// config represents the server configuration
config *Config
// routes is the collection of all routes
routes []*route
}
// NewRouter creates a new Router instance
func NewRouter(config *Config) *Router {
return &Router{config: config}
}
// add adds an abitary route to the router
func (r *Router) add(method, prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
for _, m := range middlewares {
handler = m(handler)
}
route := &route{method: method, path: r.config.PathPrefix + prefix, handler: handler, exact: exact}
r.routes = append(
r.routes,
route,
)
return route
}
// GET adds GET route
func (r *Router) GET(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
return r.add(http.MethodGet, prefix, exact, handler, middlewares...)
}
// OPTIONS adds OPTIONS route
func (r *Router) OPTIONS(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
return r.add(http.MethodOptions, prefix, exact, handler, middlewares...)
}
// HEAD adds HEAD route
func (r *Router) HEAD(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
return r.add(http.MethodHead, prefix, exact, handler, middlewares...)
}
// ServeHTTP serves routes
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Attach timer to the context
req, timeoutCancel := startRequestTimer(req)
defer timeoutCancel()
// Create the response writer which times out on write
rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
// Get/create request ID
reqID := r.getRequestID(req)
// Replace request IP from headers
r.replaceRemoteAddr(req)
rw.Header().Set(httpheaders.Server, defaultServerName)
rw.Header().Set(httpheaders.XRequestID, reqID)
for _, rr := range r.routes {
if rr.isMatch(req) {
if !rr.silent {
LogRequest(reqID, req)
}
rr.handler(reqID, rw, req)
return
}
}
// Means that we have not found matching route
LogRequest(reqID, req)
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
r.NotFoundHandler(reqID, rw, req)
}
// NotFoundHandler is default 404 handler
func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
rw.Header().Set(httpheaders.ContentType, "text/plain")
rw.WriteHeader(http.StatusNotFound)
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
return nil
}
// OkHandler is a default 200 OK handler
func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
rw.Header().Set(httpheaders.ContentType, "text/plain")
rw.WriteHeader(http.StatusOK)
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
return nil
}
// getRequestID tries to read request id from headers or from lambda
// context or generates a new one if nothing found.
func (r *Router) getRequestID(req *http.Request) string {
// Get request ID from headers (if any)
reqID := req.Header.Get(httpheaders.XRequestID)
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
lambdaContextVal := req.Header.Get(httpheaders.XAmznRequestContextHeader)
if len(lambdaContextVal) > 0 {
var lambdaContext struct {
RequestID string `json:"requestId"`
}
err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext)
if err == nil && len(lambdaContext.RequestID) > 0 {
reqID = lambdaContext.RequestID
}
}
}
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
reqID, _ = nanoid.New()
}
return reqID
}
// replaceRemoteAddr rewrites the req.RemoteAddr property from request headers
func (r *Router) replaceRemoteAddr(req *http.Request) {
cfConnectingIP := req.Header.Get(httpheaders.CFConnectingIP)
xForwardedFor := req.Header.Get(httpheaders.XForwardedFor)
xRealIP := req.Header.Get(httpheaders.XRealIP)
switch {
case len(cfConnectingIP) > 0:
replaceRemoteAddr(req, cfConnectingIP)
case len(xForwardedFor) > 0:
if index := strings.Index(xForwardedFor, ","); index > 0 {
xForwardedFor = xForwardedFor[:index]
}
replaceRemoteAddr(req, xForwardedFor)
case len(xRealIP) > 0:
replaceRemoteAddr(req, xRealIP)
}
}
// replaceRemoteAddr sets the req.RemoteAddr for request
func replaceRemoteAddr(req *http.Request, ip string) {
_, port, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
port = "80"
}
req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port)
}
// isMatch checks that a request matches route
func (r *route) isMatch(req *http.Request) bool {
methodMatches := r.method == req.Method
notExactPathMathes := !r.exact && strings.HasPrefix(req.URL.Path, r.path)
exactPathMatches := r.exact && req.URL.Path == r.path
return methodMatches && (notExactPathMathes || exactPathMatches)
}
// Silent sets Silent flag which supresses logs to true. We do not need to log
// requests like /health of /favicon.ico
func (r *route) Silent() *route {
r.silent = true
return r
}

296
server/router_test.go Normal file
View File

@@ -0,0 +1,296 @@
package server
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
"github.com/imgproxy/imgproxy/v3/httpheaders"
)
type RouterTestSuite struct {
suite.Suite
router *Router
}
func (s *RouterTestSuite) SetupTest() {
c := NewConfigFromEnv()
c.PathPrefix = "/api"
s.router = NewRouter(c)
}
func TestRouterSuite(t *testing.T) {
suite.Run(t, new(RouterTestSuite))
}
// TestHTTPMethods tests route methods registration and HTTP requests
func (s *RouterTestSuite) TestHTTPMethods() {
var capturedMethod string
var capturedPath string
getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
capturedMethod = req.Method
capturedPath = req.URL.Path
rw.WriteHeader(200)
rw.Write([]byte("GET response"))
return nil
}
optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
capturedMethod = req.Method
capturedPath = req.URL.Path
rw.WriteHeader(200)
rw.Write([]byte("OPTIONS response"))
return nil
}
headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
capturedMethod = req.Method
capturedPath = req.URL.Path
rw.WriteHeader(200)
return nil
}
// Register routes with different configurations
s.router.GET("/get-test", true, getHandler) // exact match
s.router.OPTIONS("/options-test", false, optionsHandler) // prefix match
s.router.HEAD("/head-test", true, headHandler) // exact match
tests := []struct {
name string
requestMethod string
requestPath string
expectedBody string
expectedPath string
}{
{
name: "GET",
requestMethod: http.MethodGet,
requestPath: "/api/get-test",
expectedBody: "GET response",
expectedPath: "/api/get-test",
},
{
name: "OPTIONS",
requestMethod: http.MethodOptions,
requestPath: "/api/options-test",
expectedBody: "OPTIONS response",
expectedPath: "/api/options-test",
},
{
name: "OPTIONSPrefixed",
requestMethod: http.MethodOptions,
requestPath: "/api/options-test/sub",
expectedBody: "OPTIONS response",
expectedPath: "/api/options-test/sub",
},
{
name: "HEAD",
requestMethod: http.MethodHead,
requestPath: "/api/head-test",
expectedBody: "",
expectedPath: "/api/head-test",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
req := httptest.NewRequest(tt.requestMethod, tt.requestPath, nil)
rw := httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
s.Require().Equal(tt.expectedBody, rw.Body.String())
s.Require().Equal(tt.requestMethod, capturedMethod)
s.Require().Equal(tt.expectedPath, capturedPath)
})
}
}
// TestMiddlewareOrder checks middleware ordering and functionality
func (s *RouterTestSuite) TestMiddlewareOrder() {
var order []string
middleware1 := func(next RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
order = append(order, "middleware1")
return next(reqID, rw, req)
}
}
middleware2 := func(next RouteHandler) RouteHandler {
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
order = append(order, "middleware2")
return next(reqID, rw, req)
}
}
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
order = append(order, "handler")
rw.WriteHeader(200)
return nil
}
s.router.GET("/test", true, handler, middleware2, middleware1)
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
rw := httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
// Middleware should execute in the order they are passed (first added first)
s.Require().Equal([]string{"middleware1", "middleware2", "handler"}, order)
}
// TestServeHTTP tests ServeHTTP method
func (s *RouterTestSuite) TestServeHTTP() {
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
rw.Header().Set("Custom-Header", "test-value")
rw.WriteHeader(200)
rw.Write([]byte("success"))
return nil
}
s.router.GET("/test", true, handler)
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
rw := httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
s.Require().Equal(200, rw.Code)
s.Require().Equal("success", rw.Body.String())
s.Require().Equal("test-value", rw.Header().Get("Custom-Header"))
s.Require().Equal(defaultServerName, rw.Header().Get(httpheaders.Server))
s.Require().NotEmpty(rw.Header().Get(httpheaders.XRequestID))
}
// TestRequestID checks request ID generation and validation
func (s *RouterTestSuite) TestRequestID() {
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
rw.WriteHeader(200)
return nil
}
s.router.GET("/test", true, handler)
// Test request ID passthrough (if present)
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
req.Header.Set(httpheaders.XRequestID, "valid-id-123")
rw := httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
s.Require().Equal("valid-id-123", rw.Header().Get(httpheaders.XRequestID))
// Test invalid request ID (should generate a new one)
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
req.Header.Set(httpheaders.XRequestID, "invalid id with spaces!")
rw = httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
generatedID := rw.Header().Get(httpheaders.XRequestID)
s.Require().NotEqual("invalid id with spaces!", generatedID)
s.Require().NotEmpty(generatedID)
// Test no request ID (should generate a new one)
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
rw = httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
generatedID = rw.Header().Get(httpheaders.XRequestID)
s.Require().NotEmpty(generatedID)
s.Require().Regexp(`^[A-Za-z0-9_\-]+$`, generatedID)
}
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
rw.WriteHeader(200)
return nil
}
s.router.GET("/test", true, handler)
// Test with valid Lambda context
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
req.Header.Set(httpheaders.XAmznRequestContextHeader, `{"requestId":"lambda-req-123"}`)
rw := httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
s.Require().Equal("lambda-req-123", rw.Header().Get(httpheaders.XRequestID))
}
// Test IP address handling
func (s *RouterTestSuite) TestReplaceIP() {
var capturedRemoteAddr string
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
capturedRemoteAddr = req.RemoteAddr
rw.WriteHeader(200)
return nil
}
s.router.GET("/test", true, handler)
tests := []struct {
name string
originalAddr string
headers map[string]string
expectedAddr string
}{
{
name: "CFConnectingIP",
originalAddr: "original:8080",
headers: map[string]string{
httpheaders.CFConnectingIP: "1.2.3.4",
},
expectedAddr: "1.2.3.4:8080",
},
{
name: "XForwardedForMulti",
originalAddr: "original:8080",
headers: map[string]string{
httpheaders.XForwardedFor: "5.6.7.8, 9.10.11.12",
},
expectedAddr: "5.6.7.8:8080",
},
{
name: "XForwardedForSingle",
originalAddr: "original:8080",
headers: map[string]string{
httpheaders.XForwardedFor: "13.14.15.16",
},
expectedAddr: "13.14.15.16:8080",
},
{
name: "XRealIP",
originalAddr: "original:8080",
headers: map[string]string{
httpheaders.XRealIP: "17.18.19.20",
},
expectedAddr: "17.18.19.20:8080",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
req.RemoteAddr = tt.originalAddr
for header, value := range tt.headers {
req.Header.Set(header, value)
}
rw := httptest.NewRecorder()
s.router.ServeHTTP(rw, req)
s.Require().Equal(tt.expectedAddr, capturedRemoteAddr)
})
}
}

82
server/server.go Normal file
View File

@@ -0,0 +1,82 @@
package server
import (
"context"
"fmt"
golog "log"
"net/http"
log "github.com/sirupsen/logrus"
"golang.org/x/net/netutil"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/reuseport"
)
const (
// maxHeaderBytes represents max bytes in request header
maxHeaderBytes = 1 << 20
)
// Server represents the HTTP server wrapper struct
type Server struct {
router *Router
server *http.Server
}
// Start starts the http server. cancel is called in case server failed to start, but it happened
// asynchronously. It should cancel the upstream context.
func Start(cancel context.CancelFunc, router *Router) (*Server, error) {
l, err := reuseport.Listen(router.config.Network, router.config.Bind, router.config.SocketReusePort)
if err != nil {
cancel()
return nil, fmt.Errorf("can't start server: %s", err)
}
if router.config.MaxClients > 0 {
l = netutil.LimitListener(l, router.config.MaxClients)
}
errLogger := golog.New(
log.WithField("source", "http_server").WriterLevel(log.ErrorLevel),
"", 0,
)
srv := &http.Server{
Handler: router,
ReadTimeout: router.config.ReadRequestTimeout,
MaxHeaderBytes: maxHeaderBytes,
ErrorLog: errLogger,
}
if config.KeepAliveTimeout > 0 {
srv.IdleTimeout = router.config.KeepAliveTimeout
} else {
srv.SetKeepAlivesEnabled(false)
}
go func() {
log.Infof("Starting server at %s", router.config.Bind)
if err := srv.Serve(l); err != nil && err != http.ErrServerClosed {
log.Error(err)
}
cancel()
}()
return &Server{
router: router,
server: srv,
}, nil
}
// Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) {
log.Info("Shutting down the server...")
ctx, close := context.WithTimeout(ctx, s.router.config.GracefulTimeout)
defer close()
s.server.Shutdown(ctx)
}

268
server/server_test.go Normal file
View File

@@ -0,0 +1,268 @@
package server
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/stretchr/testify/suite"
)
type ServerTestSuite struct {
suite.Suite
config *Config
blankRouter *Router
}
func (s *ServerTestSuite) SetupTest() {
config.Reset()
s.config = NewConfigFromEnv()
s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
s.blankRouter = NewRouter(s.config)
}
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
ctx, cancel := context.WithCancel(s.T().Context())
// Track if cancel was called using atomic
var cancelCalled atomic.Bool
cancelWrapper := func() {
cancel()
cancelCalled.Store(true)
}
invalidConfig := &Config{
Network: "tcp",
Bind: "invalid-address", // Invalid address
}
r := NewRouter(invalidConfig)
server, err := Start(cancelWrapper, r)
s.Require().Error(err)
s.Nil(server)
s.Contains(err.Error(), "can't start server")
// Check if cancel was called using Eventually
s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
// Also verify the context was cancelled
s.Require().Eventually(func() bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}, 100*time.Millisecond, 10*time.Millisecond)
}
func (s *ServerTestSuite) TestShutdown() {
_, cancel := context.WithCancel(context.Background())
defer cancel()
server, err := Start(cancel, s.blankRouter)
s.Require().NoError(err)
s.NotNil(server)
// Test graceful shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
defer shutdownCancel()
// Should not panic or hang
s.NotPanics(func() {
server.Shutdown(shutdownCtx)
})
}
func (s *ServerTestSuite) TestWithCORS() {
tests := []struct {
name string
corsAllowOrigin string
expectedOrigin string
expectedMethods string
}{
{
name: "WithCORSOrigin",
corsAllowOrigin: "https://example.com",
expectedOrigin: "https://example.com",
expectedMethods: "GET, OPTIONS",
},
{
name: "NoCORSOrigin",
corsAllowOrigin: "",
expectedOrigin: "",
expectedMethods: "",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
config := &Config{
CORSAllowOrigin: tt.corsAllowOrigin,
}
router := NewRouter(config)
wrappedHandler := router.WithCORS(s.mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
wrappedHandler("test-req-id", rw, req)
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
})
}
}
func (s *ServerTestSuite) TestWithSecret() {
tests := []struct {
name string
secret string
authHeader string
expectError bool
}{
{
name: "ValidSecret",
secret: "test-secret",
authHeader: "Bearer test-secret",
},
{
name: "InvalidSecret",
secret: "foo-secret",
authHeader: "Bearer wrong-secret",
expectError: true,
},
{
name: "NoSecretConfigured",
secret: "",
authHeader: "",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
config := &Config{
Secret: tt.secret,
}
router := NewRouter(config)
wrappedHandler := router.WithSecret(s.mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
if tt.authHeader != "" {
req.Header.Set(httpheaders.Authorization, tt.authHeader)
}
rw := httptest.NewRecorder()
err := wrappedHandler("test-req-id", rw, req)
if tt.expectError {
s.Require().Error(err)
} else {
s.Require().NoError(err)
}
})
}
}
func (s *ServerTestSuite) TestIntoSuccess() {
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
rw.WriteHeader(http.StatusOK)
return nil
}
wrappedHandler := s.blankRouter.WithReportError(mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
wrappedHandler("test-req-id", rw, req)
s.Equal(http.StatusOK, rw.Code)
}
func (s *ServerTestSuite) TestIntoWithError() {
testError := errors.New("test error")
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
return testError
}
wrappedHandler := s.blankRouter.WithReportError(mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
wrappedHandler("test-req-id", rw, req)
s.Equal(http.StatusInternalServerError, rw.Code)
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
}
func (s *ServerTestSuite) TestIntoPanicWithError() {
testError := errors.New("panic error")
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
panic(testError)
}
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
s.NotPanics(func() {
err := wrappedHandler("test-req-id", rw, req)
s.Require().Error(err, "panic error")
})
s.Equal(http.StatusOK, rw.Code)
}
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
panic(http.ErrAbortHandler)
}
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
// Should re-panic with ErrAbortHandler
s.Panics(func() {
wrappedHandler("test-req-id", rw, req)
})
}
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
panic("string panic")
}
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
// Should re-panic with non-error panics
s.NotPanics(func() {
err := wrappedHandler("test-req-id", rw, req)
s.Require().Error(err, "string panic")
})
}
func TestServerTestSuite(t *testing.T) {
suite.Run(t, new(ServerTestSuite))
}

View File

@@ -0,0 +1,52 @@
package server
import (
"net/http"
"time"
)
// timeoutResponse manages response writer with timeout. It has
// timeout on all write methods.
type timeoutResponse struct {
http.ResponseWriter
controller *http.ResponseController
timeout time.Duration
}
// newTimeoutResponse creates a new timeoutResponse
func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
return &timeoutResponse{
ResponseWriter: rw,
controller: http.NewResponseController(rw),
timeout: timeout,
}
}
// Write implements http.ResponseWriter.Write
func (rw *timeoutResponse) Write(b []byte) (int, error) {
var (
n int
err error
)
rw.withWriteDeadline(func() {
n, err = rw.ResponseWriter.Write(b)
})
return n, err
}
// Header returns current HTTP headers
func (rw *timeoutResponse) Header() http.Header {
return rw.ResponseWriter.Header()
}
// withWriteDeadline executes a Write* function with a deadline
func (rw *timeoutResponse) withWriteDeadline(f func()) {
deadline := time.Now().Add(rw.timeout)
// Set write deadline
rw.controller.SetWriteDeadline(deadline)
// Reset write deadline after method has finished
defer rw.controller.SetWriteDeadline(time.Time{})
f()
}

View File

@@ -1,4 +1,6 @@
package router // timer.go contains methods for storing, retrieving and checking
// timer in a request context.
package server
import ( import (
"context" "context"
@@ -9,8 +11,10 @@ import (
"github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/ierrors"
) )
// timerSinceCtxKey represents a context key for start time.
type timerSinceCtxKey struct{} type timerSinceCtxKey struct{}
// startRequestTimer starts a new request timer.
func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) { func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
ctx := r.Context() ctx := r.Context()
ctx = context.WithValue(ctx, timerSinceCtxKey{}, time.Now()) ctx = context.WithValue(ctx, timerSinceCtxKey{}, time.Now())
@@ -18,17 +22,20 @@ func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
return r.WithContext(ctx), cancel return r.WithContext(ctx), cancel
} }
func ctxTime(ctx context.Context) time.Duration { // requestStartedAt returns the duration since the timer started in the context.
func requestStartedAt(ctx context.Context) time.Duration {
if t, ok := ctx.Value(timerSinceCtxKey{}).(time.Time); ok { if t, ok := ctx.Value(timerSinceCtxKey{}).(time.Time); ok {
return time.Since(t) return time.Since(t)
} }
return 0 return 0
} }
// CheckTimeout checks if the request context has timed out or cancelled and returns
// wrapped error.
func CheckTimeout(ctx context.Context) error { func CheckTimeout(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
d := ctxTime(ctx) d := requestStartedAt(ctx)
err := ctx.Err() err := ctx.Err()
switch err { switch err {
@@ -37,7 +44,7 @@ func CheckTimeout(ctx context.Context) error {
case context.DeadlineExceeded: case context.DeadlineExceeded:
return newRequestTimeoutError(d) return newRequestTimeoutError(d)
default: default:
return ierrors.Wrap(err, 0) return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryTimeout))
} }
default: default:
return nil return nil

67
server/timer_test.go Normal file
View File

@@ -0,0 +1,67 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestCheckTimeout(t *testing.T) {
tests := []struct {
name string
setup func() context.Context
fail bool
}{
{
name: "WithoutTimeout",
setup: context.Background,
fail: false,
},
{
name: "ActiveTimerContext",
setup: func() context.Context {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
newReq, _ := startRequestTimer(req)
return newReq.Context()
},
fail: false,
},
{
name: "CancelledContext",
setup: func() context.Context {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
newReq, cancel := startRequestTimer(req)
cancel() // Cancel immediately
return newReq.Context()
},
fail: true,
},
{
name: "DeadlineExceeded",
setup: func() context.Context {
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
defer cancel()
time.Sleep(time.Millisecond * 10) // Ensure timeout
return ctx
},
fail: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.setup()
err := CheckTimeout(ctx)
if tt.fail {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -12,11 +12,12 @@ import (
"github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/cookies" "github.com/imgproxy/imgproxy/v3/cookies"
"github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/ierrors"
"github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagedata"
"github.com/imgproxy/imgproxy/v3/metrics" "github.com/imgproxy/imgproxy/v3/metrics"
"github.com/imgproxy/imgproxy/v3/metrics/stats" "github.com/imgproxy/imgproxy/v3/metrics/stats"
"github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/options"
"github.com/imgproxy/imgproxy/v3/router" "github.com/imgproxy/imgproxy/v3/server"
) )
var ( var (
@@ -44,7 +45,7 @@ var (
} }
) )
func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) { func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) error {
stats.IncImagesInProgress() stats.IncImagesInProgress()
defer stats.DecImagesInProgress() defer stats.DecImagesInProgress()
@@ -65,18 +66,24 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
if config.CookiePassthrough { if config.CookiePassthrough {
cookieJar, err = cookies.JarFromRequest(r) cookieJar, err = cookies.JarFromRequest(r)
checkErr(ctx, "streaming", err) if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
}
} }
req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar) req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
defer req.Cancel() defer req.Cancel()
checkErr(ctx, "streaming", err) if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
}
res, err := req.Send() res, err := req.Send()
if res != nil { if res != nil {
defer res.Body.Close() defer res.Body.Close()
} }
checkErr(ctx, "streaming", err) if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
}
for _, k := range streamRespHeaders { for _, k := range streamRespHeaders {
vv := res.Header.Values(k) vv := res.Header.Values(k)
@@ -116,7 +123,7 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
copyerr = nil copyerr = nil
} }
router.LogResponse( server.LogResponse(
reqID, r, res.StatusCode, nil, reqID, r, res.StatusCode, nil,
log.Fields{ log.Fields{
"image_url": imageURL, "image_url": imageURL,
@@ -127,4 +134,6 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
if copyerr != nil { if copyerr != nil {
panic(http.ErrAbortHandler) panic(http.ErrAbortHandler)
} }
return nil
} }