diff --git a/errors.go b/errors.go index 6e2b2c8f..9fd45702 100644 --- a/errors.go +++ b/errors.go @@ -7,11 +7,23 @@ import ( "github.com/imgproxy/imgproxy/v3/ierrors" ) +// Monitoring error categories +const ( + categoryTimeout = "timeout" + categoryImageDataSize = "image_data_size" + categoryPathParsing = "path_parsing" + categorySecurity = "security" + categoryQueue = "queue" + categoryDownload = "download" + categoryProcessing = "processing" + categoryIO = "IO" + categoryStreaming = "streaming" +) + type ( ResponseWriteError struct{ error } InvalidURLError string TooManyRequestsError struct{} - InvalidSecretError struct{} ) func newResponseWriteError(cause error) *ierrors.Error { @@ -53,15 +65,3 @@ func newTooManyRequestsError() error { } func (e TooManyRequestsError) Error() string { return "Too many requests" } - -func newInvalidSecretError() error { - return ierrors.Wrap( - InvalidSecretError{}, - 1, - ierrors.WithStatusCode(http.StatusForbidden), - ierrors.WithPublicMessage("Forbidden"), - ierrors.WithShouldReport(false), - ) -} - -func (e InvalidSecretError) Error() string { return "Invalid secret" } diff --git a/handlers/health.go b/handlers/health.go new file mode 100644 index 00000000..009a6f65 --- /dev/null +++ b/handlers/health.go @@ -0,0 +1,46 @@ +package handlers + +import ( + "net/http" + + "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/server" + "github.com/imgproxy/imgproxy/v3/vips" +) + +var imgproxyIsRunningMsg = []byte("imgproxy is running") + +// HealthHandler handles the health check requests +func HealthHandler(reqID string, rw http.ResponseWriter, r *http.Request) error { + var ( + status int + msg []byte + ierr *ierrors.Error + ) + + if err := vips.Health(); err == nil { + status = http.StatusOK + msg = imgproxyIsRunningMsg + } else { + status = http.StatusInternalServerError + msg = []byte("Error") + ierr = ierrors.Wrap(err, 1) + } + + if len(msg) == 0 { + msg = []byte{' '} + } + + // Log response only if something went wrong + if ierr != nil { + server.LogResponse(reqID, r, status, ierr) + } + + rw.Header().Set(httpheaders.ContentType, "text/plain") + rw.Header().Set(httpheaders.CacheControl, "no-cache") + rw.WriteHeader(status) + rw.Write(msg) + + return nil +} diff --git a/handlers/health_test.go b/handlers/health_test.go new file mode 100644 index 00000000..6b89fcf6 --- /dev/null +++ b/handlers/health_test.go @@ -0,0 +1,32 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/imgproxy/imgproxy/v3/httpheaders" +) + +func TestHealthHandler(t *testing.T) { + // Create a ResponseRecorder to record the response + rr := httptest.NewRecorder() + + // Call the handler function directly (no need for actual HTTP request) + HealthHandler("test-req-id", rr, nil) + + // Check that we get a valid response (either 200 or 500 depending on vips state) + assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError) + + // Check headers are set correctly + assert.Equal(t, "text/plain", rr.Header().Get(httpheaders.ContentType)) + assert.Equal(t, "no-cache", rr.Header().Get(httpheaders.CacheControl)) + + // Verify response format and content + body := rr.Body.String() + assert.NotEmpty(t, body) + + assert.Equal(t, "imgproxy is running", body) +} diff --git a/landing.go b/handlers/landing.go similarity index 94% rename from landing.go rename to handlers/landing.go index 3a1185c5..f4bd7fa5 100644 --- a/landing.go +++ b/handlers/landing.go @@ -1,6 +1,10 @@ -package main +package handlers -import "net/http" +import ( + "net/http" + + "github.com/imgproxy/imgproxy/v3/httpheaders" +) var landingTmpl = []byte(` @@ -39,8 +43,10 @@ var landingTmpl = []byte(` `) -func handleLanding(reqID string, rw http.ResponseWriter, r *http.Request) { - rw.Header().Set("Content-Type", "text/html") - rw.WriteHeader(200) +// LandingHandler handles the landing page requests +func LandingHandler(reqID string, rw http.ResponseWriter, r *http.Request) error { + rw.Header().Set(httpheaders.ContentType, "text/html") + rw.WriteHeader(http.StatusOK) rw.Write(landingTmpl) + return nil } diff --git a/httpheaders/headers.go b/httpheaders/headers.go index a97658ee..fa43e840 100644 --- a/httpheaders/headers.go +++ b/httpheaders/headers.go @@ -17,6 +17,7 @@ const ( AltSvc = "Alt-Svc" Authorization = "Authorization" CacheControl = "Cache-Control" + CFConnectingIP = "CF-Connecting-IP" Connection = "Connection" ContentDisposition = "Content-Disposition" ContentEncoding = "Content-Encoding" @@ -56,6 +57,7 @@ const ( Vary = "Vary" Via = "Via" WwwAuthenticate = "Www-Authenticate" + XAmznRequestContextHeader = "x-amzn-request-context" XContentTypeOptions = "X-Content-Type-Options" XForwardedFor = "X-Forwarded-For" XForwardedHost = "X-Forwarded-Host" @@ -63,6 +65,8 @@ const ( XFrameOptions = "X-Frame-Options" XOriginWidth = "X-Origin-Width" XOriginHeight = "X-Origin-Height" + XRealIP = "X-Real-IP" + XRequestID = "X-Request-ID" XResultWidth = "X-Result-Width" XResultHeight = "X-Result-Height" XOriginContentLength = "X-Origin-Content-Length" diff --git a/ierrors/errors.go b/ierrors/errors.go index f5a84216..dc98b108 100644 --- a/ierrors/errors.go +++ b/ierrors/errors.go @@ -7,6 +7,10 @@ import ( "strings" ) +const ( + defaultCategory = "default" +) + type Option func(*Error) type Error struct { @@ -16,6 +20,7 @@ type Error struct { statusCode int publicMessage string shouldReport bool + category string stack []uintptr } @@ -64,6 +69,14 @@ func (e *Error) Callers() []uintptr { return e.stack } +func (e *Error) Category() string { + if e.category == "" { + return defaultCategory + } + + return e.category +} + func (e *Error) FormatStackLines() []string { lines := make([]string, len(e.stack)) @@ -141,6 +154,12 @@ func WithShouldReport(report bool) Option { } } +func WithCategory(category string) Option { + return func(e *Error) { + e.category = category + } +} + func callers(skip int) []uintptr { stack := make([]uintptr, 10) n := runtime.Callers(skip+2, stack) diff --git a/main.go b/main.go index 39e75316..3ee77e67 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "github.com/imgproxy/imgproxy/v3/config/loadenv" "github.com/imgproxy/imgproxy/v3/errorreport" "github.com/imgproxy/imgproxy/v3/gliblog" + "github.com/imgproxy/imgproxy/v3/handlers" "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/logger" "github.com/imgproxy/imgproxy/v3/memory" @@ -23,10 +24,37 @@ import ( "github.com/imgproxy/imgproxy/v3/metrics/prometheus" "github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/processing" + "github.com/imgproxy/imgproxy/v3/server" "github.com/imgproxy/imgproxy/v3/version" "github.com/imgproxy/imgproxy/v3/vips" ) +const ( + faviconPath = "/favicon.ico" + healthPath = "/health" +) + +func buildRouter(r *server.Router) *server.Router { + r.GET("/", true, handlers.LandingHandler) + r.GET("", true, handlers.LandingHandler) + + r.GET( + "/", false, handleProcessing, + r.WithSecret, r.WithCORS, r.WithPanic, r.WithReportError, r.WithMetrics, + ) + + r.HEAD("/", false, r.OkHandler, r.WithCORS) + r.OPTIONS("/", false, r.OkHandler, r.WithCORS) + + r.GET(faviconPath, true, r.NotFoundHandler).Silent() + r.GET(healthPath, true, handlers.HealthHandler).Silent() + if config.HealthCheckPath != "" { + r.GET(config.HealthCheckPath, true, handlers.HealthHandler).Silent() + } + + return r +} + func initialize() error { if err := loadenv.Load(); err != nil { return err @@ -103,25 +131,21 @@ func run(ctx context.Context) error { } }() - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) if err := prometheus.StartServer(cancel); err != nil { return err } - s, err := startServer(cancel) + cfg := server.NewConfigFromEnv() + r := server.NewRouter(cfg) + s, err := server.Start(cancel, buildRouter(r)) if err != nil { return err } - defer shutdownServer(s) + defer s.Shutdown(ctx) - stop := make(chan os.Signal, 1) - signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) - - select { - case <-ctx.Done(): - case <-stop: - } + <-ctx.Done() return nil } diff --git a/metrics/prometheus/prometheus.go b/metrics/prometheus/prometheus.go index 5cecb66c..f4f7d297 100644 --- a/metrics/prometheus/prometheus.go +++ b/metrics/prometheus/prometheus.go @@ -161,7 +161,7 @@ func StartServer(cancel context.CancelFunc) error { s := http.Server{Handler: promhttp.Handler()} - l, err := reuseport.Listen("tcp", config.PrometheusBind) + l, err := reuseport.Listen("tcp", config.PrometheusBind, config.SoReuseport) if err != nil { return fmt.Errorf("Can't start Prometheus metrics server: %s", err) } diff --git a/processing/pipeline.go b/processing/pipeline.go index 3eac2264..ab9a0744 100644 --- a/processing/pipeline.go +++ b/processing/pipeline.go @@ -6,7 +6,7 @@ import ( "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/options" - "github.com/imgproxy/imgproxy/v3/router" + "github.com/imgproxy/imgproxy/v3/server" "github.com/imgproxy/imgproxy/v3/vips" ) @@ -89,7 +89,7 @@ func (p pipeline) Run(ctx context.Context, img *vips.Image, po *options.Processi return err } - if err := router.CheckTimeout(ctx); err != nil { + if err := server.CheckTimeout(ctx); err != nil { return err } } diff --git a/processing/processing.go b/processing/processing.go index cab7d72f..16c16eeb 100644 --- a/processing/processing.go +++ b/processing/processing.go @@ -12,8 +12,8 @@ import ( "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/options" - "github.com/imgproxy/imgproxy/v3/router" "github.com/imgproxy/imgproxy/v3/security" + "github.com/imgproxy/imgproxy/v3/server" "github.com/imgproxy/imgproxy/v3/svg" "github.com/imgproxy/imgproxy/v3/vips" ) @@ -173,7 +173,7 @@ func transformAnimated(ctx context.Context, img *vips.Image, po *options.Process return err } - if err = router.CheckTimeout(ctx); err != nil { + if err = server.CheckTimeout(ctx); err != nil { return err } } @@ -240,7 +240,7 @@ func saveImageToFitBytes(ctx context.Context, po *options.ProcessingOptions, img } imgdata.Close() - if err := router.CheckTimeout(ctx); err != nil { + if err := server.CheckTimeout(ctx); err != nil { return nil, err } diff --git a/processing_handler.go b/processing_handler.go index 371a64a9..0bd02822 100644 --- a/processing_handler.go +++ b/processing_handler.go @@ -1,7 +1,6 @@ package main import ( - "context" "errors" "fmt" "io" @@ -27,8 +26,8 @@ import ( "github.com/imgproxy/imgproxy/v3/metrics/stats" "github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/processing" - "github.com/imgproxy/imgproxy/v3/router" "github.com/imgproxy/imgproxy/v3/security" + "github.com/imgproxy/imgproxy/v3/server" "github.com/imgproxy/imgproxy/v3/vips" ) @@ -122,17 +121,19 @@ func setCanonical(rw http.ResponseWriter, originURL string) { } } -func writeOriginContentLengthDebugHeader(ctx context.Context, rw http.ResponseWriter, originData imagedata.ImageData) { +func writeOriginContentLengthDebugHeader(rw http.ResponseWriter, originData imagedata.ImageData) error { if !config.EnableDebugHeaders { - return + return nil } size, err := originData.Size() if err != nil { - checkErr(ctx, "image_data_size", err) + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) } rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size)) + + return nil } func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) { @@ -146,13 +147,13 @@ func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) { rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight)) } -func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) { +func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) error { // We read the size of the image data here, so we can set Content-Length header. // This indireclty ensures that the image data is fully read from the source, no // errors happened. resultSize, err := resultData.Size() if err != nil { - checkErr(r.Context(), "image_data_size", err) + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) } contentDisposition := httpheaders.ContentDispositionValue( @@ -183,18 +184,19 @@ func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, sta ierr = newResponseWriteError(err) if config.ReportIOErrors { - sendErr(r.Context(), "IO", ierr) - errorreport.Report(ierr, r) + return ierrors.Wrap(ierr, 0, ierrors.WithCategory(categoryIO), ierrors.WithShouldReport(true)) } } - router.LogResponse( + server.LogResponse( reqID, r, statusCode, ierr, log.Fields{ "image_url": originURL, "processing_options": po, }, ) + + return nil } func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders http.Header) { @@ -202,7 +204,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite setVary(rw) rw.WriteHeader(304) - router.LogResponse( + server.LogResponse( reqID, r, 304, nil, log.Fields{ "image_url": originURL, @@ -211,37 +213,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite ) } -func sendErr(ctx context.Context, errType string, err error) { - send := true - - if ierr, ok := err.(*ierrors.Error); ok { - switch ierr.StatusCode() { - case http.StatusServiceUnavailable: - errType = "timeout" - case 499: - // Don't need to send a "request cancelled" error - send = false - } - } - - if send { - metrics.SendError(ctx, errType, err) - } -} - -func sendErrAndPanic(ctx context.Context, errType string, err error) { - sendErr(ctx, errType, err) - panic(err) -} - -func checkErr(ctx context.Context, errType string, err error) { - if err == nil { - return - } - sendErrAndPanic(ctx, errType, err) -} - -func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { +func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) error { stats.IncRequestsInProgress() defer stats.DecRequestsInProgress() @@ -263,19 +235,22 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { signature = path[:signatureEnd] path = path[signatureEnd:] } else { - sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf( - http.StatusNotFound, "Invalid path: %s", path), + return ierrors.Wrap( + newInvalidURLErrorf(http.StatusNotFound, "Invalid path: %s", path), 0, + ierrors.WithCategory(categoryPathParsing), ) } path = fixPath(path) if err := security.VerifySignature(signature, path); err != nil { - sendErrAndPanic(ctx, "security", err) + return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity)) } po, imageURL, err := options.ParsePath(path, r.Header) - checkErr(ctx, "path_parsing", err) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryPathParsing)) + } var imageOrigin any if u, uerr := url.Parse(imageURL); uerr == nil { @@ -295,19 +270,20 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { metrics.SetMetadata(ctx, metricsMeta) err = security.VerifySourceURL(imageURL) - checkErr(ctx, "security", err) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity)) + } if po.Raw { - streamOriginImage(ctx, reqID, r, rw, po, imageURL) - return + return streamOriginImage(ctx, reqID, r, rw, po, imageURL) } // SVG is a special case. Though saving to svg is not supported, SVG->SVG is. if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG { - sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf( + return ierrors.Wrap(newInvalidURLErrorf( http.StatusUnprocessableEntity, "Resulting image format is not supported: %s", po.Format, - )) + ), 0, ierrors.WithCategory(categoryPathParsing)) } imgRequestHeader := make(http.Header) @@ -339,7 +315,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { } // The heavy part starts here, so we need to restrict worker number - func() { + err = func() error { defer metrics.StartQueueSegment(ctx)() err = processingSem.Acquire(ctx, 1) @@ -347,12 +323,21 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { // We don't actually need to check timeout here, // but it's an easy way to check if this is an actual timeout // or the request was canceled - checkErr(ctx, "queue", router.CheckTimeout(ctx)) + if terr := server.CheckTimeout(ctx); terr != nil { + return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } + // We should never reach this line as err could be only ctx.Err() // and we've already checked for it. But beter safe than sorry - sendErrAndPanic(ctx, "queue", err) + + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryQueue)) } + + return nil }() + if err != nil { + return err + } defer processingSem.Release(1) stats.IncImagesInProgress() @@ -375,7 +360,9 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { if config.CookiePassthrough { downloadOpts.CookieJar, err = cookies.JarFromRequest(r) - checkErr(ctx, "download", err) + if err != nil { + return nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload)) + } } return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts) @@ -393,26 +380,28 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { } respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers()) - return + return nil default: // This may be a request timeout error or a request cancelled error. // Check it before moving further - checkErr(ctx, "timeout", router.CheckTimeout(ctx)) + if terr := server.CheckTimeout(ctx); terr != nil { + return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } ierr := ierrors.Wrap(err, 0) if config.ReportDownloadingErrors { ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true)) } - sendErr(ctx, "download", ierr) - if imagedata.FallbackImage == nil { - panic(ierr) + return ierr } - // We didn't panic, so the error is not reported. - // Report it now + // Just send error + metrics.SendError(ctx, categoryDownload, ierr) + + // We didn't return, so we have to report error if ierr.ShouldReport() { errorreport.Report(ierr, r) } @@ -433,27 +422,33 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { } } - checkErr(ctx, "timeout", router.CheckTimeout(ctx)) + if terr := server.CheckTimeout(ctx); terr != nil { + return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } if config.ETagEnabled && statusCode == http.StatusOK { - imgDataMatch, terr := etagHandler.SetActualImageData(originData, originHeaders) - if terr == nil { - rw.Header().Set("ETag", etagHandler.GenerateActualETag()) + imgDataMatch, eerr := etagHandler.SetActualImageData(originData, originHeaders) + if eerr != nil && config.ReportIOErrors { + return ierrors.Wrap(eerr, 0, ierrors.WithCategory(categoryIO)) + } - if imgDataMatch && etagHandler.ProcessingOptionsMatch() { - respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders) - return - } + rw.Header().Set("ETag", etagHandler.GenerateActualETag()) + + if imgDataMatch && etagHandler.ProcessingOptionsMatch() { + respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders) + return nil } } - checkErr(ctx, "timeout", router.CheckTimeout(ctx)) + if terr := server.CheckTimeout(ctx); terr != nil { + return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } if !vips.SupportsLoad(originData.Format()) { - sendErrAndPanic(ctx, "processing", newInvalidURLErrorf( + return ierrors.Wrap(newInvalidURLErrorf( http.StatusUnprocessableEntity, "Source image format is not supported: %s", originData.Format(), - )) + ), 0, ierrors.WithCategory(categoryProcessing)) } result, err := func() (*processing.Result, error) { @@ -468,18 +463,31 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { defer result.OutData.Close() } - if err != nil { - // First, check if the processing error wasn't caused by an image data error - checkErr(ctx, "download", originData.Error()) - - // If it wasn't, than it was a processing error - sendErrAndPanic(ctx, "processing", err) + // First, check if the processing error wasn't caused by an image data error + if originData.Error() != nil { + return ierrors.Wrap(originData.Error(), 0, ierrors.WithCategory(categoryDownload)) } - checkErr(ctx, "timeout", router.CheckTimeout(ctx)) + // If it wasn't, than it was a processing error + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryProcessing)) + } + + if terr := server.CheckTimeout(ctx); terr != nil { + return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } writeDebugHeaders(rw, result) - writeOriginContentLengthDebugHeader(ctx, rw, originData) - respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders) + err = writeOriginContentLengthDebugHeader(rw, originData) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) + } + + err = respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders) + if err != nil { + return err + } + + return nil } diff --git a/processing_handler_test.go b/processing_handler_test.go index f1f7208c..2b07270a 100644 --- a/processing_handler_test.go +++ b/processing_handler_test.go @@ -22,7 +22,7 @@ import ( "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/options" - "github.com/imgproxy/imgproxy/v3/router" + "github.com/imgproxy/imgproxy/v3/server" "github.com/imgproxy/imgproxy/v3/svg" "github.com/imgproxy/imgproxy/v3/testutil" "github.com/imgproxy/imgproxy/v3/vips" @@ -31,7 +31,7 @@ import ( type ProcessingHandlerTestSuite struct { suite.Suite - router *router.Router + router *server.Router } func (s *ProcessingHandlerTestSuite) SetupSuite() { @@ -48,7 +48,7 @@ func (s *ProcessingHandlerTestSuite) SetupSuite() { logrus.SetOutput(io.Discard) - s.router = buildRouter() + s.router = buildRouter(server.NewRouter(server.NewConfigFromEnv())) } func (s *ProcessingHandlerTestSuite) TeardownSuite() { diff --git a/reuseport/listen_no_reuseport.go b/reuseport/listen_no_reuseport.go index 0548b39f..edcfcb71 100644 --- a/reuseport/listen_no_reuseport.go +++ b/reuseport/listen_no_reuseport.go @@ -7,12 +7,10 @@ import ( "net" log "github.com/sirupsen/logrus" - - "github.com/imgproxy/imgproxy/v3/config" ) -func Listen(network, address string) (net.Listener, error) { - if config.SoReuseport { +func Listen(network, address string, reuse bool) (net.Listener, error) { + if reuse { log.Warning("SO_REUSEPORT support is not implemented for your OS or Go version") } diff --git a/reuseport/listen_reuseport.go b/reuseport/listen_reuseport.go index e649c276..6d00b1bb 100644 --- a/reuseport/listen_reuseport.go +++ b/reuseport/listen_reuseport.go @@ -10,12 +10,10 @@ import ( "syscall" "golang.org/x/sys/unix" - - "github.com/imgproxy/imgproxy/v3/config" ) -func Listen(network, address string) (net.Listener, error) { - if !config.SoReuseport { +func Listen(network, address string, reuse bool) (net.Listener, error) { + if !reuse { return net.Listen(network, address) } diff --git a/router/router.go b/router/router.go deleted file mode 100644 index c4f3109e..00000000 --- a/router/router.go +++ /dev/null @@ -1,174 +0,0 @@ -package router - -import ( - "encoding/json" - "net" - "net/http" - "regexp" - "strings" - - nanoid "github.com/matoous/go-nanoid/v2" - - "github.com/imgproxy/imgproxy/v3/config" -) - -const ( - xRequestIDHeader = "X-Request-ID" - xAmznRequestContextHeader = "x-amzn-request-context" -) - -var ( - requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`) -) - -type RouteHandler func(string, http.ResponseWriter, *http.Request) - -type route struct { - Method string - Prefix string - Handler RouteHandler - Exact bool -} - -type Router struct { - prefix string - healthRoutes []string - faviconRoute string - - Routes []*route - HealthHandler RouteHandler -} - -func (r *route) isMatch(req *http.Request) bool { - if r.Method != req.Method { - return false - } - - if r.Exact { - return req.URL.Path == r.Prefix - } - - return strings.HasPrefix(req.URL.Path, r.Prefix) -} - -func New(prefix string) *Router { - healthRoutes := []string{prefix + "/health"} - if len(config.HealthCheckPath) > 0 { - healthRoutes = append(healthRoutes, prefix+config.HealthCheckPath) - } - - return &Router{ - prefix: prefix, - healthRoutes: healthRoutes, - faviconRoute: prefix + "/favicon.ico", - Routes: make([]*route, 0), - } -} - -func (r *Router) Add(method, prefix string, handler RouteHandler, exact bool) { - // Don't add routes with empty prefix - if len(r.prefix+prefix) == 0 { - return - } - - r.Routes = append( - r.Routes, - &route{Method: method, Prefix: r.prefix + prefix, Handler: handler, Exact: exact}, - ) -} - -func (r *Router) GET(prefix string, handler RouteHandler, exact bool) { - r.Add(http.MethodGet, prefix, handler, exact) -} - -func (r *Router) OPTIONS(prefix string, handler RouteHandler, exact bool) { - r.Add(http.MethodOptions, prefix, handler, exact) -} - -func (r *Router) HEAD(prefix string, handler RouteHandler, exact bool) { - r.Add(http.MethodHead, prefix, handler, exact) -} - -func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - req, timeoutCancel := startRequestTimer(req) - defer timeoutCancel() - - rw = newTimeoutResponse(rw) - - reqID := req.Header.Get(xRequestIDHeader) - - if len(reqID) == 0 || !requestIDRe.MatchString(reqID) { - if lambdaContextVal := req.Header.Get(xAmznRequestContextHeader); len(lambdaContextVal) > 0 { - var lambdaContext struct { - RequestID string `json:"requestId"` - } - - err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext) - if err == nil && len(lambdaContext.RequestID) > 0 { - reqID = lambdaContext.RequestID - } - } - } - - if len(reqID) == 0 || !requestIDRe.MatchString(reqID) { - reqID, _ = nanoid.New() - } - - rw.Header().Set("Server", "imgproxy") - rw.Header().Set(xRequestIDHeader, reqID) - - if req.Method == http.MethodGet { - if r.HealthHandler != nil { - for _, healthRoute := range r.healthRoutes { - if req.URL.Path == healthRoute { - r.HealthHandler(reqID, rw, req) - return - } - } - } - - if req.URL.Path == r.faviconRoute { - // TODO: Add a real favicon maybe? - rw.Header().Set("Content-Type", "text/plain") - rw.WriteHeader(404) - // Write a single byte to make AWS Lambda happy - rw.Write([]byte{' '}) - return - } - } - - if ip := req.Header.Get("CF-Connecting-IP"); len(ip) != 0 { - replaceRemoteAddr(req, ip) - } else if ip := req.Header.Get("X-Forwarded-For"); len(ip) != 0 { - if index := strings.Index(ip, ","); index > 0 { - ip = ip[:index] - } - replaceRemoteAddr(req, ip) - } else if ip := req.Header.Get("X-Real-IP"); len(ip) != 0 { - replaceRemoteAddr(req, ip) - } - - LogRequest(reqID, req) - - for _, rr := range r.Routes { - if rr.isMatch(req) { - rr.Handler(reqID, rw, req) - return - } - } - - LogResponse(reqID, req, 404, newRouteNotDefinedError(req.URL.Path)) - - rw.Header().Set("Content-Type", "text/plain") - rw.WriteHeader(404) - rw.Write([]byte{' '}) -} - -func replaceRemoteAddr(req *http.Request, ip string) { - _, port, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - port = "80" - } - - req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port) -} diff --git a/router/timeout_response.go b/router/timeout_response.go deleted file mode 100644 index 1a1ca2c0..00000000 --- a/router/timeout_response.go +++ /dev/null @@ -1,43 +0,0 @@ -package router - -import ( - "net/http" - "time" - - "github.com/imgproxy/imgproxy/v3/config" -) - -type timeoutResponse struct { - http.ResponseWriter - controller *http.ResponseController -} - -func newTimeoutResponse(rw http.ResponseWriter) http.ResponseWriter { - return &timeoutResponse{ - ResponseWriter: rw, - controller: http.NewResponseController(rw), - } -} - -func (rw *timeoutResponse) WriteHeader(statusCode int) { - rw.withWriteDeadline(func() { - rw.ResponseWriter.WriteHeader(statusCode) - }) -} - -func (rw *timeoutResponse) Write(b []byte) (int, error) { - var ( - n int - err error - ) - rw.withWriteDeadline(func() { - n, err = rw.ResponseWriter.Write(b) - }) - return n, err -} - -func (rw *timeoutResponse) withWriteDeadline(f func()) { - rw.controller.SetWriteDeadline(time.Now().Add(time.Duration(config.WriteResponseTimeout) * time.Second)) - defer rw.controller.SetWriteDeadline(time.Time{}) - f() -} diff --git a/server.go b/server.go deleted file mode 100644 index eefd6c08..00000000 --- a/server.go +++ /dev/null @@ -1,204 +0,0 @@ -package main - -import ( - "context" - "crypto/subtle" - "fmt" - golog "log" - "net/http" - "time" - - log "github.com/sirupsen/logrus" - "golang.org/x/net/netutil" - - "github.com/imgproxy/imgproxy/v3/config" - "github.com/imgproxy/imgproxy/v3/errorreport" - "github.com/imgproxy/imgproxy/v3/ierrors" - "github.com/imgproxy/imgproxy/v3/metrics" - "github.com/imgproxy/imgproxy/v3/reuseport" - "github.com/imgproxy/imgproxy/v3/router" - "github.com/imgproxy/imgproxy/v3/vips" -) - -var imgproxyIsRunningMsg = []byte("imgproxy is running") - -func buildRouter() *router.Router { - r := router.New(config.PathPrefix) - - r.GET("/", handleLanding, true) - r.GET("", handleLanding, true) - - r.GET("/", withMetrics(withPanicHandler(withCORS(withSecret(handleProcessing)))), false) - - r.HEAD("/", withCORS(handleHead), false) - r.OPTIONS("/", withCORS(handleHead), false) - - r.HealthHandler = handleHealth - - return r -} - -func startServer(cancel context.CancelFunc) (*http.Server, error) { - l, err := reuseport.Listen(config.Network, config.Bind) - if err != nil { - return nil, fmt.Errorf("can't start server: %s", err) - } - - if config.MaxClients > 0 { - l = netutil.LimitListener(l, config.MaxClients) - } - - errLogger := golog.New( - log.WithField("source", "http_server").WriterLevel(log.ErrorLevel), - "", 0, - ) - - s := &http.Server{ - Handler: buildRouter(), - ReadTimeout: time.Duration(config.ReadRequestTimeout) * time.Second, - MaxHeaderBytes: 1 << 20, - ErrorLog: errLogger, - } - - if config.KeepAliveTimeout > 0 { - s.IdleTimeout = time.Duration(config.KeepAliveTimeout) * time.Second - } else { - s.SetKeepAlivesEnabled(false) - } - - go func() { - log.Infof("Starting server at %s", config.Bind) - if err := s.Serve(l); err != nil && err != http.ErrServerClosed { - log.Error(err) - } - cancel() - }() - - return s, nil -} - -func shutdownServer(s *http.Server) { - log.Info("Shutting down the server...") - - ctx, close := context.WithTimeout(context.Background(), 5*time.Second) - defer close() - - s.Shutdown(ctx) -} - -func withMetrics(h router.RouteHandler) router.RouteHandler { - if !metrics.Enabled() { - return h - } - - return func(reqID string, rw http.ResponseWriter, r *http.Request) { - ctx, metricsCancel, rw := metrics.StartRequest(r.Context(), rw, r) - defer metricsCancel() - - h(reqID, rw, r.WithContext(ctx)) - } -} - -func withCORS(h router.RouteHandler) router.RouteHandler { - return func(reqID string, rw http.ResponseWriter, r *http.Request) { - if len(config.AllowOrigin) > 0 { - rw.Header().Set("Access-Control-Allow-Origin", config.AllowOrigin) - rw.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - } - - h(reqID, rw, r) - } -} - -func withSecret(h router.RouteHandler) router.RouteHandler { - if len(config.Secret) == 0 { - return h - } - - authHeader := []byte(fmt.Sprintf("Bearer %s", config.Secret)) - - return func(reqID string, rw http.ResponseWriter, r *http.Request) { - if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), authHeader) == 1 { - h(reqID, rw, r) - } else { - panic(newInvalidSecretError()) - } - } -} - -func withPanicHandler(h router.RouteHandler) router.RouteHandler { - return func(reqID string, rw http.ResponseWriter, r *http.Request) { - ctx := errorreport.StartRequest(r) - r = r.WithContext(ctx) - - errorreport.SetMetadata(r, "Request ID", reqID) - - defer func() { - if rerr := recover(); rerr != nil { - if rerr == http.ErrAbortHandler { - panic(rerr) - } - - err, ok := rerr.(error) - if !ok { - panic(rerr) - } - - ierr := ierrors.Wrap(err, 0) - - if ierr.ShouldReport() { - errorreport.Report(err, r) - } - - router.LogResponse(reqID, r, ierr.StatusCode(), ierr) - - rw.Header().Set("Content-Type", "text/plain") - rw.WriteHeader(ierr.StatusCode()) - - if config.DevelopmentErrorsMode { - rw.Write([]byte(ierr.Error())) - } else { - rw.Write([]byte(ierr.PublicMessage())) - } - } - }() - - h(reqID, rw, r) - } -} - -func handleHealth(reqID string, rw http.ResponseWriter, r *http.Request) { - var ( - status int - msg []byte - ierr *ierrors.Error - ) - - if err := vips.Health(); err == nil { - status = http.StatusOK - msg = imgproxyIsRunningMsg - } else { - status = http.StatusInternalServerError - msg = []byte("Error") - ierr = ierrors.Wrap(err, 1) - } - - if len(msg) == 0 { - msg = []byte{' '} - } - - // Log response only if something went wrong - if ierr != nil { - router.LogResponse(reqID, r, status, ierr) - } - - rw.Header().Set("Content-Type", "text/plain") - rw.Header().Set("Cache-Control", "no-cache") - rw.WriteHeader(status) - rw.Write(msg) -} - -func handleHead(reqID string, rw http.ResponseWriter, r *http.Request) { - router.LogResponse(reqID, r, 200, nil) - rw.WriteHeader(200) -} diff --git a/server/config.go b/server/config.go new file mode 100644 index 00000000..8f9cef15 --- /dev/null +++ b/server/config.go @@ -0,0 +1,49 @@ +package server + +import ( + "time" + + "github.com/imgproxy/imgproxy/v3/config" +) + +const ( + // gracefulTimeout represents graceful shutdown timeout + gracefulTimeout = time.Duration(5 * time.Second) +) + +// Config represents HTTP server config +type Config struct { + Listen string // Address to listen on + Network string // Network type (tcp, unix) + Bind string // Bind address + PathPrefix string // Path prefix for the server + MaxClients int // Maximum number of concurrent clients + ReadRequestTimeout time.Duration // Timeout for reading requests + KeepAliveTimeout time.Duration // Timeout for keep-alive connections + GracefulTimeout time.Duration // Timeout for graceful shutdown + CORSAllowOrigin string // CORS allowed origin + Secret string // Secret for authorization + DevelopmentErrorsMode bool // Enable development mode for detailed error messages + SocketReusePort bool // Enable SO_REUSEPORT socket option + HealthCheckPath string // Health check path from config + WriteResponseTimeout time.Duration +} + +// NewConfigFromEnv creates a new Config instance from environment variables +func NewConfigFromEnv() *Config { + return &Config{ + Network: config.Network, + Bind: config.Bind, + PathPrefix: config.PathPrefix, + MaxClients: config.MaxClients, + ReadRequestTimeout: time.Duration(config.ReadRequestTimeout) * time.Second, + KeepAliveTimeout: time.Duration(config.KeepAliveTimeout) * time.Second, + GracefulTimeout: gracefulTimeout, + CORSAllowOrigin: config.AllowOrigin, + Secret: config.Secret, + DevelopmentErrorsMode: config.DevelopmentErrorsMode, + SocketReusePort: config.SoReuseport, + HealthCheckPath: config.HealthCheckPath, + WriteResponseTimeout: time.Duration(config.WriteResponseTimeout) * time.Second, + } +} diff --git a/router/errors.go b/server/errors.go similarity index 68% rename from router/errors.go rename to server/errors.go index c39e529e..8074d04e 100644 --- a/router/errors.go +++ b/server/errors.go @@ -1,6 +1,7 @@ -package router +package server import ( + "context" "fmt" "net/http" "time" @@ -12,6 +13,7 @@ type ( RouteNotDefinedError string RequestCancelledError string RequestTimeoutError string + InvalidSecretError struct{} ) func newRouteNotDefinedError(path string) *ierrors.Error { @@ -33,11 +35,16 @@ func newRequestCancelledError(after time.Duration) *ierrors.Error { ierrors.WithStatusCode(499), ierrors.WithPublicMessage("Cancelled"), ierrors.WithShouldReport(false), + ierrors.WithCategory(categoryTimeout), ) } func (e RequestCancelledError) Error() string { return string(e) } +func (e RequestCancelledError) Unwrap() error { + return context.Canceled +} + func newRequestTimeoutError(after time.Duration) *ierrors.Error { return ierrors.Wrap( RequestTimeoutError(fmt.Sprintf("Request was timed out after %v", after)), @@ -45,7 +52,24 @@ func newRequestTimeoutError(after time.Duration) *ierrors.Error { ierrors.WithStatusCode(http.StatusServiceUnavailable), ierrors.WithPublicMessage("Gateway Timeout"), ierrors.WithShouldReport(false), + ierrors.WithCategory(categoryTimeout), ) } func (e RequestTimeoutError) Error() string { return string(e) } + +func (e RequestTimeoutError) Unwrap() error { + return context.DeadlineExceeded +} + +func newInvalidSecretError() error { + return ierrors.Wrap( + InvalidSecretError{}, + 1, + ierrors.WithStatusCode(http.StatusForbidden), + ierrors.WithPublicMessage("Forbidden"), + ierrors.WithShouldReport(false), + ) +} + +func (e InvalidSecretError) Error() string { return "Invalid secret" } diff --git a/router/logging.go b/server/logging.go similarity index 93% rename from router/logging.go rename to server/logging.go index f496739e..8721d5be 100644 --- a/router/logging.go +++ b/server/logging.go @@ -1,4 +1,4 @@ -package router +package server import ( "net" @@ -59,6 +59,6 @@ func LogResponse(reqID string, r *http.Request, status int, err *ierrors.Error, log.WithFields(fields).Logf( level, - "Completed in %s %s", ctxTime(r.Context()), r.RequestURI, + "Completed in %s %s", requestStartedAt(r.Context()), r.RequestURI, ) } diff --git a/server/middlewares.go b/server/middlewares.go new file mode 100644 index 00000000..2c9ca69e --- /dev/null +++ b/server/middlewares.go @@ -0,0 +1,145 @@ +package server + +import ( + "context" + "crypto/subtle" + "errors" + "fmt" + "net/http" + + "github.com/imgproxy/imgproxy/v3/errorreport" + "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/metrics" +) + +const ( + categoryTimeout = "timeout" +) + +// WithMetrics wraps RouteHandler with metrics handling. +func (r *Router) WithMetrics(h RouteHandler) RouteHandler { + if !metrics.Enabled() { + return h + } + + return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + ctx, metricsCancel, rw := metrics.StartRequest(req.Context(), rw, req) + defer metricsCancel() + + return h(reqID, rw, req.WithContext(ctx)) + } +} + +// WithCORS wraps RouteHandler with CORS handling +func (r *Router) WithCORS(h RouteHandler) RouteHandler { + if len(r.config.CORSAllowOrigin) == 0 { + return h + } + + return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin) + rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS") + + return h(reqID, rw, req) + } +} + +// WithSecret wraps RouteHandler with secret handling +func (r *Router) WithSecret(h RouteHandler) RouteHandler { + if len(r.config.Secret) == 0 { + return h + } + + authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret) + + return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 { + return h(reqID, rw, req) + } else { + return newInvalidSecretError() + } + } +} + +// WithPanic recovers panic and converts it to normal error +func (r *Router) WithPanic(h RouteHandler) RouteHandler { + return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) { + defer func() { + // try to recover from panic + rerr := recover() + if rerr == nil { + return + } + + // abort handler is an exception of net/http, we should simply repanic it. + // it will supress the stack trace + if rerr == http.ErrAbortHandler { + panic(rerr) + } + + // let's recover error value from panic if it has panicked with error + err, ok := rerr.(error) + if !ok { + err = fmt.Errorf("panic: %v", err) + } + + retErr = err + }() + + return h(reqID, rw, r) + } +} + +// WithReportError handles error reporting. +// It should be placed after `WithMetrics`, but before `WithPanic`. +func (r *Router) WithReportError(h RouteHandler) RouteHandler { + return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + // Open the error context + ctx := errorreport.StartRequest(req) + req = req.WithContext(ctx) + errorreport.SetMetadata(req, "Request ID", reqID) + + // Call the underlying handler passing the context downwards + err := h(reqID, rw, req) + if err == nil { + return nil + } + + // Wrap a resulting error into ierrors.Error + ierr := ierrors.Wrap(err, 0) + + // Get the error category + errCat := ierr.Category() + + // Exception: any context.DeadlineExceeded error is timeout + if errors.Is(ierr, context.DeadlineExceeded) { + errCat = categoryTimeout + } + + // We do not need to send any canceled context + if !errors.Is(ierr, context.Canceled) { + metrics.SendError(ctx, errCat, err) + } + + // Report error to error collectors + if ierr.ShouldReport() { + errorreport.Report(ierr, req) + } + + // Log response and format the error output + LogResponse(reqID, req, ierr.StatusCode(), ierr) + + // Error message: either is public message or full development error + rw.Header().Set(httpheaders.ContentType, "text/plain") + rw.WriteHeader(ierr.StatusCode()) + + if r.config.DevelopmentErrorsMode { + rw.Write([]byte(ierr.Error())) + } else { + rw.Write([]byte(ierr.PublicMessage())) + } + + return nil + } +} diff --git a/server/router.go b/server/router.go new file mode 100644 index 00000000..c8e44d60 --- /dev/null +++ b/server/router.go @@ -0,0 +1,209 @@ +package server + +import ( + "encoding/json" + "net" + "net/http" + "regexp" + "strings" + + nanoid "github.com/matoous/go-nanoid/v2" + + "github.com/imgproxy/imgproxy/v3/httpheaders" +) + +const ( + // defaultServerName is the default name of the server + defaultServerName = "imgproxy" +) + +var ( + // requestIDRe is a regular expression for validating request IDs + requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`) +) + +// RouteHandler is a function that handles HTTP requests. +type RouteHandler func(string, http.ResponseWriter, *http.Request) error + +// Middleware is a function that wraps a RouteHandler with additional functionality. +type Middleware func(next RouteHandler) RouteHandler + +// route represents a single route in the router. +type route struct { + method string // method is the HTTP method for a route + path string // path represents a route path + exact bool // exact means that path must match exactly, otherwise any prefixed matches + handler RouteHandler // handler is the function that handles the route + silent bool // Silent route (no logs) +} + +// Router is responsible for routing HTTP requests +type Router struct { + // config represents the server configuration + config *Config + + // routes is the collection of all routes + routes []*route +} + +// NewRouter creates a new Router instance +func NewRouter(config *Config) *Router { + return &Router{config: config} +} + +// add adds an abitary route to the router +func (r *Router) add(method, prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route { + for _, m := range middlewares { + handler = m(handler) + } + + route := &route{method: method, path: r.config.PathPrefix + prefix, handler: handler, exact: exact} + + r.routes = append( + r.routes, + route, + ) + + return route +} + +// GET adds GET route +func (r *Router) GET(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route { + return r.add(http.MethodGet, prefix, exact, handler, middlewares...) +} + +// OPTIONS adds OPTIONS route +func (r *Router) OPTIONS(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route { + return r.add(http.MethodOptions, prefix, exact, handler, middlewares...) +} + +// HEAD adds HEAD route +func (r *Router) HEAD(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route { + return r.add(http.MethodHead, prefix, exact, handler, middlewares...) +} + +// ServeHTTP serves routes +func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // Attach timer to the context + req, timeoutCancel := startRequestTimer(req) + defer timeoutCancel() + + // Create the response writer which times out on write + rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout) + + // Get/create request ID + reqID := r.getRequestID(req) + + // Replace request IP from headers + r.replaceRemoteAddr(req) + + rw.Header().Set(httpheaders.Server, defaultServerName) + rw.Header().Set(httpheaders.XRequestID, reqID) + + for _, rr := range r.routes { + if rr.isMatch(req) { + if !rr.silent { + LogRequest(reqID, req) + } + + rr.handler(reqID, rw, req) + return + } + } + + // Means that we have not found matching route + LogRequest(reqID, req) + LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path)) + r.NotFoundHandler(reqID, rw, req) +} + +// NotFoundHandler is default 404 handler +func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error { + rw.Header().Set(httpheaders.ContentType, "text/plain") + rw.WriteHeader(http.StatusNotFound) + rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy + + return nil +} + +// OkHandler is a default 200 OK handler +func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error { + rw.Header().Set(httpheaders.ContentType, "text/plain") + rw.WriteHeader(http.StatusOK) + rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy + + return nil +} + +// getRequestID tries to read request id from headers or from lambda +// context or generates a new one if nothing found. +func (r *Router) getRequestID(req *http.Request) string { + // Get request ID from headers (if any) + reqID := req.Header.Get(httpheaders.XRequestID) + + if len(reqID) == 0 || !requestIDRe.MatchString(reqID) { + lambdaContextVal := req.Header.Get(httpheaders.XAmznRequestContextHeader) + + if len(lambdaContextVal) > 0 { + var lambdaContext struct { + RequestID string `json:"requestId"` + } + + err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext) + if err == nil && len(lambdaContext.RequestID) > 0 { + reqID = lambdaContext.RequestID + } + } + } + + if len(reqID) == 0 || !requestIDRe.MatchString(reqID) { + reqID, _ = nanoid.New() + } + + return reqID +} + +// replaceRemoteAddr rewrites the req.RemoteAddr property from request headers +func (r *Router) replaceRemoteAddr(req *http.Request) { + cfConnectingIP := req.Header.Get(httpheaders.CFConnectingIP) + xForwardedFor := req.Header.Get(httpheaders.XForwardedFor) + xRealIP := req.Header.Get(httpheaders.XRealIP) + + switch { + case len(cfConnectingIP) > 0: + replaceRemoteAddr(req, cfConnectingIP) + case len(xForwardedFor) > 0: + if index := strings.Index(xForwardedFor, ","); index > 0 { + xForwardedFor = xForwardedFor[:index] + } + replaceRemoteAddr(req, xForwardedFor) + case len(xRealIP) > 0: + replaceRemoteAddr(req, xRealIP) + } +} + +// replaceRemoteAddr sets the req.RemoteAddr for request +func replaceRemoteAddr(req *http.Request, ip string) { + _, port, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + port = "80" + } + + req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port) +} + +// isMatch checks that a request matches route +func (r *route) isMatch(req *http.Request) bool { + methodMatches := r.method == req.Method + notExactPathMathes := !r.exact && strings.HasPrefix(req.URL.Path, r.path) + exactPathMatches := r.exact && req.URL.Path == r.path + + return methodMatches && (notExactPathMathes || exactPathMatches) +} + +// Silent sets Silent flag which supresses logs to true. We do not need to log +// requests like /health of /favicon.ico +func (r *route) Silent() *route { + r.silent = true + return r +} diff --git a/server/router_test.go b/server/router_test.go new file mode 100644 index 00000000..4e704a1c --- /dev/null +++ b/server/router_test.go @@ -0,0 +1,296 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/imgproxy/imgproxy/v3/httpheaders" +) + +type RouterTestSuite struct { + suite.Suite + router *Router +} + +func (s *RouterTestSuite) SetupTest() { + c := NewConfigFromEnv() + c.PathPrefix = "/api" + s.router = NewRouter(c) +} + +func TestRouterSuite(t *testing.T) { + suite.Run(t, new(RouterTestSuite)) +} + +// TestHTTPMethods tests route methods registration and HTTP requests +func (s *RouterTestSuite) TestHTTPMethods() { + var capturedMethod string + var capturedPath string + + getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + capturedMethod = req.Method + capturedPath = req.URL.Path + rw.WriteHeader(200) + rw.Write([]byte("GET response")) + return nil + } + + optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + capturedMethod = req.Method + capturedPath = req.URL.Path + rw.WriteHeader(200) + rw.Write([]byte("OPTIONS response")) + return nil + } + + headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + capturedMethod = req.Method + capturedPath = req.URL.Path + rw.WriteHeader(200) + return nil + } + + // Register routes with different configurations + s.router.GET("/get-test", true, getHandler) // exact match + s.router.OPTIONS("/options-test", false, optionsHandler) // prefix match + s.router.HEAD("/head-test", true, headHandler) // exact match + + tests := []struct { + name string + requestMethod string + requestPath string + expectedBody string + expectedPath string + }{ + { + name: "GET", + requestMethod: http.MethodGet, + requestPath: "/api/get-test", + expectedBody: "GET response", + expectedPath: "/api/get-test", + }, + { + name: "OPTIONS", + requestMethod: http.MethodOptions, + requestPath: "/api/options-test", + expectedBody: "OPTIONS response", + expectedPath: "/api/options-test", + }, + { + name: "OPTIONSPrefixed", + requestMethod: http.MethodOptions, + requestPath: "/api/options-test/sub", + expectedBody: "OPTIONS response", + expectedPath: "/api/options-test/sub", + }, + { + name: "HEAD", + requestMethod: http.MethodHead, + requestPath: "/api/head-test", + expectedBody: "", + expectedPath: "/api/head-test", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest(tt.requestMethod, tt.requestPath, nil) + rw := httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + s.Require().Equal(tt.expectedBody, rw.Body.String()) + s.Require().Equal(tt.requestMethod, capturedMethod) + s.Require().Equal(tt.expectedPath, capturedPath) + }) + } +} + +// TestMiddlewareOrder checks middleware ordering and functionality +func (s *RouterTestSuite) TestMiddlewareOrder() { + var order []string + + middleware1 := func(next RouteHandler) RouteHandler { + return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + order = append(order, "middleware1") + return next(reqID, rw, req) + } + } + + middleware2 := func(next RouteHandler) RouteHandler { + return func(reqID string, rw http.ResponseWriter, req *http.Request) error { + order = append(order, "middleware2") + return next(reqID, rw, req) + } + } + + handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + order = append(order, "handler") + rw.WriteHeader(200) + return nil + } + + s.router.GET("/test", true, handler, middleware2, middleware1) + + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + rw := httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + // Middleware should execute in the order they are passed (first added first) + s.Require().Equal([]string{"middleware1", "middleware2", "handler"}, order) +} + +// TestServeHTTP tests ServeHTTP method +func (s *RouterTestSuite) TestServeHTTP() { + handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + rw.Header().Set("Custom-Header", "test-value") + rw.WriteHeader(200) + rw.Write([]byte("success")) + return nil + } + + s.router.GET("/test", true, handler) + + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + rw := httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + s.Require().Equal(200, rw.Code) + s.Require().Equal("success", rw.Body.String()) + s.Require().Equal("test-value", rw.Header().Get("Custom-Header")) + s.Require().Equal(defaultServerName, rw.Header().Get(httpheaders.Server)) + s.Require().NotEmpty(rw.Header().Get(httpheaders.XRequestID)) +} + +// TestRequestID checks request ID generation and validation +func (s *RouterTestSuite) TestRequestID() { + handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + rw.WriteHeader(200) + return nil + } + + s.router.GET("/test", true, handler) + + // Test request ID passthrough (if present) + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.Header.Set(httpheaders.XRequestID, "valid-id-123") + rw := httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + s.Require().Equal("valid-id-123", rw.Header().Get(httpheaders.XRequestID)) + + // Test invalid request ID (should generate a new one) + req = httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.Header.Set(httpheaders.XRequestID, "invalid id with spaces!") + rw = httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + generatedID := rw.Header().Get(httpheaders.XRequestID) + s.Require().NotEqual("invalid id with spaces!", generatedID) + s.Require().NotEmpty(generatedID) + + // Test no request ID (should generate a new one) + req = httptest.NewRequest(http.MethodGet, "/api/test", nil) + rw = httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + generatedID = rw.Header().Get(httpheaders.XRequestID) + s.Require().NotEmpty(generatedID) + s.Require().Regexp(`^[A-Za-z0-9_\-]+$`, generatedID) +} + +// TestLambdaRequestIDExtraction checks AWS lambda request id extraction +func (s *RouterTestSuite) TestLambdaRequestIDExtraction() { + handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + rw.WriteHeader(200) + return nil + } + + s.router.GET("/test", true, handler) + + // Test with valid Lambda context + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.Header.Set(httpheaders.XAmznRequestContextHeader, `{"requestId":"lambda-req-123"}`) + rw := httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + s.Require().Equal("lambda-req-123", rw.Header().Get(httpheaders.XRequestID)) +} + +// Test IP address handling +func (s *RouterTestSuite) TestReplaceIP() { + var capturedRemoteAddr string + handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + capturedRemoteAddr = req.RemoteAddr + rw.WriteHeader(200) + return nil + } + + s.router.GET("/test", true, handler) + + tests := []struct { + name string + originalAddr string + headers map[string]string + expectedAddr string + }{ + { + name: "CFConnectingIP", + originalAddr: "original:8080", + headers: map[string]string{ + httpheaders.CFConnectingIP: "1.2.3.4", + }, + expectedAddr: "1.2.3.4:8080", + }, + { + name: "XForwardedForMulti", + originalAddr: "original:8080", + headers: map[string]string{ + httpheaders.XForwardedFor: "5.6.7.8, 9.10.11.12", + }, + expectedAddr: "5.6.7.8:8080", + }, + { + name: "XForwardedForSingle", + originalAddr: "original:8080", + headers: map[string]string{ + httpheaders.XForwardedFor: "13.14.15.16", + }, + expectedAddr: "13.14.15.16:8080", + }, + { + name: "XRealIP", + originalAddr: "original:8080", + headers: map[string]string{ + httpheaders.XRealIP: "17.18.19.20", + }, + expectedAddr: "17.18.19.20:8080", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.RemoteAddr = tt.originalAddr + + for header, value := range tt.headers { + req.Header.Set(header, value) + } + + rw := httptest.NewRecorder() + + s.router.ServeHTTP(rw, req) + + s.Require().Equal(tt.expectedAddr, capturedRemoteAddr) + }) + } +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 00000000..a1d6d1d6 --- /dev/null +++ b/server/server.go @@ -0,0 +1,82 @@ +package server + +import ( + "context" + "fmt" + golog "log" + "net/http" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/netutil" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/reuseport" +) + +const ( + // maxHeaderBytes represents max bytes in request header + maxHeaderBytes = 1 << 20 +) + +// Server represents the HTTP server wrapper struct +type Server struct { + router *Router + server *http.Server +} + +// Start starts the http server. cancel is called in case server failed to start, but it happened +// asynchronously. It should cancel the upstream context. +func Start(cancel context.CancelFunc, router *Router) (*Server, error) { + l, err := reuseport.Listen(router.config.Network, router.config.Bind, router.config.SocketReusePort) + if err != nil { + cancel() + return nil, fmt.Errorf("can't start server: %s", err) + } + + if router.config.MaxClients > 0 { + l = netutil.LimitListener(l, router.config.MaxClients) + } + + errLogger := golog.New( + log.WithField("source", "http_server").WriterLevel(log.ErrorLevel), + "", 0, + ) + + srv := &http.Server{ + Handler: router, + ReadTimeout: router.config.ReadRequestTimeout, + MaxHeaderBytes: maxHeaderBytes, + ErrorLog: errLogger, + } + + if config.KeepAliveTimeout > 0 { + srv.IdleTimeout = router.config.KeepAliveTimeout + } else { + srv.SetKeepAlivesEnabled(false) + } + + go func() { + log.Infof("Starting server at %s", router.config.Bind) + + if err := srv.Serve(l); err != nil && err != http.ErrServerClosed { + log.Error(err) + } + + cancel() + }() + + return &Server{ + router: router, + server: srv, + }, nil +} + +// Shutdown gracefully shuts down the server +func (s *Server) Shutdown(ctx context.Context) { + log.Info("Shutting down the server...") + + ctx, close := context.WithTimeout(ctx, s.router.config.GracefulTimeout) + defer close() + + s.server.Shutdown(ctx) +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 00000000..0b4b2989 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,268 @@ +package server + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/stretchr/testify/suite" +) + +type ServerTestSuite struct { + suite.Suite + config *Config + blankRouter *Router +} + +func (s *ServerTestSuite) SetupTest() { + config.Reset() + s.config = NewConfigFromEnv() + s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment + s.blankRouter = NewRouter(s.config) +} + +func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error { + return nil +} + +func (s *ServerTestSuite) TestStartServerWithInvalidBind() { + ctx, cancel := context.WithCancel(s.T().Context()) + + // Track if cancel was called using atomic + var cancelCalled atomic.Bool + cancelWrapper := func() { + cancel() + cancelCalled.Store(true) + } + + invalidConfig := &Config{ + Network: "tcp", + Bind: "invalid-address", // Invalid address + } + + r := NewRouter(invalidConfig) + + server, err := Start(cancelWrapper, r) + + s.Require().Error(err) + s.Nil(server) + s.Contains(err.Error(), "can't start server") + + // Check if cancel was called using Eventually + s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond) + + // Also verify the context was cancelled + s.Require().Eventually(func() bool { + select { + case <-ctx.Done(): + return true + default: + return false + } + }, 100*time.Millisecond, 10*time.Millisecond) +} + +func (s *ServerTestSuite) TestShutdown() { + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + server, err := Start(cancel, s.blankRouter) + s.Require().NoError(err) + s.NotNil(server) + + // Test graceful shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second) + defer shutdownCancel() + + // Should not panic or hang + s.NotPanics(func() { + server.Shutdown(shutdownCtx) + }) +} + +func (s *ServerTestSuite) TestWithCORS() { + tests := []struct { + name string + corsAllowOrigin string + expectedOrigin string + expectedMethods string + }{ + { + name: "WithCORSOrigin", + corsAllowOrigin: "https://example.com", + expectedOrigin: "https://example.com", + expectedMethods: "GET, OPTIONS", + }, + { + name: "NoCORSOrigin", + corsAllowOrigin: "", + expectedOrigin: "", + expectedMethods: "", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + config := &Config{ + CORSAllowOrigin: tt.corsAllowOrigin, + } + router := NewRouter(config) + + wrappedHandler := router.WithCORS(s.mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + rw := httptest.NewRecorder() + + wrappedHandler("test-req-id", rw, req) + + s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin)) + s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods)) + }) + } +} + +func (s *ServerTestSuite) TestWithSecret() { + tests := []struct { + name string + secret string + authHeader string + expectError bool + }{ + { + name: "ValidSecret", + secret: "test-secret", + authHeader: "Bearer test-secret", + }, + { + name: "InvalidSecret", + secret: "foo-secret", + authHeader: "Bearer wrong-secret", + expectError: true, + }, + { + name: "NoSecretConfigured", + secret: "", + authHeader: "", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + config := &Config{ + Secret: tt.secret, + } + router := NewRouter(config) + + wrappedHandler := router.WithSecret(s.mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + if tt.authHeader != "" { + req.Header.Set(httpheaders.Authorization, tt.authHeader) + } + rw := httptest.NewRecorder() + + err := wrappedHandler("test-req-id", rw, req) + + if tt.expectError { + s.Require().Error(err) + } else { + s.Require().NoError(err) + } + }) + } +} + +func (s *ServerTestSuite) TestIntoSuccess() { + mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + rw.WriteHeader(http.StatusOK) + return nil + } + + wrappedHandler := s.blankRouter.WithReportError(mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + rw := httptest.NewRecorder() + + wrappedHandler("test-req-id", rw, req) + + s.Equal(http.StatusOK, rw.Code) +} + +func (s *ServerTestSuite) TestIntoWithError() { + testError := errors.New("test error") + mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + return testError + } + + wrappedHandler := s.blankRouter.WithReportError(mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + rw := httptest.NewRecorder() + + wrappedHandler("test-req-id", rw, req) + + s.Equal(http.StatusInternalServerError, rw.Code) + s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType)) +} + +func (s *ServerTestSuite) TestIntoPanicWithError() { + testError := errors.New("panic error") + mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + panic(testError) + } + + wrappedHandler := s.blankRouter.WithPanic(mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + rw := httptest.NewRecorder() + + s.NotPanics(func() { + err := wrappedHandler("test-req-id", rw, req) + s.Require().Error(err, "panic error") + }) + + s.Equal(http.StatusOK, rw.Code) +} + +func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() { + mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + panic(http.ErrAbortHandler) + } + + wrappedHandler := s.blankRouter.WithPanic(mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + rw := httptest.NewRecorder() + + // Should re-panic with ErrAbortHandler + s.Panics(func() { + wrappedHandler("test-req-id", rw, req) + }) +} + +func (s *ServerTestSuite) TestIntoPanicWithNonError() { + mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error { + panic("string panic") + } + + wrappedHandler := s.blankRouter.WithPanic(mockHandler) + + req := httptest.NewRequest("GET", "/test", nil) + rw := httptest.NewRecorder() + + // Should re-panic with non-error panics + s.NotPanics(func() { + err := wrappedHandler("test-req-id", rw, req) + s.Require().Error(err, "string panic") + }) +} + +func TestServerTestSuite(t *testing.T) { + suite.Run(t, new(ServerTestSuite)) +} diff --git a/server/timeout_response.go b/server/timeout_response.go new file mode 100644 index 00000000..43199131 --- /dev/null +++ b/server/timeout_response.go @@ -0,0 +1,52 @@ +package server + +import ( + "net/http" + "time" +) + +// timeoutResponse manages response writer with timeout. It has +// timeout on all write methods. +type timeoutResponse struct { + http.ResponseWriter + controller *http.ResponseController + timeout time.Duration +} + +// newTimeoutResponse creates a new timeoutResponse +func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter { + return &timeoutResponse{ + ResponseWriter: rw, + controller: http.NewResponseController(rw), + timeout: timeout, + } +} + +// Write implements http.ResponseWriter.Write +func (rw *timeoutResponse) Write(b []byte) (int, error) { + var ( + n int + err error + ) + rw.withWriteDeadline(func() { + n, err = rw.ResponseWriter.Write(b) + }) + return n, err +} + +// Header returns current HTTP headers +func (rw *timeoutResponse) Header() http.Header { + return rw.ResponseWriter.Header() +} + +// withWriteDeadline executes a Write* function with a deadline +func (rw *timeoutResponse) withWriteDeadline(f func()) { + deadline := time.Now().Add(rw.timeout) + + // Set write deadline + rw.controller.SetWriteDeadline(deadline) + + // Reset write deadline after method has finished + defer rw.controller.SetWriteDeadline(time.Time{}) + f() +} diff --git a/router/timer.go b/server/timer.go similarity index 59% rename from router/timer.go rename to server/timer.go index f3b135eb..8336a15d 100644 --- a/router/timer.go +++ b/server/timer.go @@ -1,4 +1,6 @@ -package router +// timer.go contains methods for storing, retrieving and checking +// timer in a request context. +package server import ( "context" @@ -9,8 +11,10 @@ import ( "github.com/imgproxy/imgproxy/v3/ierrors" ) +// timerSinceCtxKey represents a context key for start time. type timerSinceCtxKey struct{} +// startRequestTimer starts a new request timer. func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) { ctx := r.Context() ctx = context.WithValue(ctx, timerSinceCtxKey{}, time.Now()) @@ -18,17 +22,20 @@ func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) { return r.WithContext(ctx), cancel } -func ctxTime(ctx context.Context) time.Duration { +// requestStartedAt returns the duration since the timer started in the context. +func requestStartedAt(ctx context.Context) time.Duration { if t, ok := ctx.Value(timerSinceCtxKey{}).(time.Time); ok { return time.Since(t) } return 0 } +// CheckTimeout checks if the request context has timed out or cancelled and returns +// wrapped error. func CheckTimeout(ctx context.Context) error { select { case <-ctx.Done(): - d := ctxTime(ctx) + d := requestStartedAt(ctx) err := ctx.Err() switch err { @@ -37,7 +44,7 @@ func CheckTimeout(ctx context.Context) error { case context.DeadlineExceeded: return newRequestTimeoutError(d) default: - return ierrors.Wrap(err, 0) + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryTimeout)) } default: return nil diff --git a/server/timer_test.go b/server/timer_test.go new file mode 100644 index 00000000..b4b5b537 --- /dev/null +++ b/server/timer_test.go @@ -0,0 +1,67 @@ +package server + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCheckTimeout(t *testing.T) { + tests := []struct { + name string + setup func() context.Context + fail bool + }{ + { + name: "WithoutTimeout", + setup: context.Background, + fail: false, + }, + { + name: "ActiveTimerContext", + setup: func() context.Context { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + newReq, _ := startRequestTimer(req) + return newReq.Context() + }, + fail: false, + }, + { + name: "CancelledContext", + setup: func() context.Context { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + newReq, cancel := startRequestTimer(req) + cancel() // Cancel immediately + return newReq.Context() + }, + fail: true, + }, + { + name: "DeadlineExceeded", + setup: func() context.Context { + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + time.Sleep(time.Millisecond * 10) // Ensure timeout + return ctx + }, + fail: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setup() + err := CheckTimeout(ctx) + + if tt.fail { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/stream.go b/stream.go index b4c88541..2bca15e6 100644 --- a/stream.go +++ b/stream.go @@ -12,11 +12,12 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/cookies" "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/metrics" "github.com/imgproxy/imgproxy/v3/metrics/stats" "github.com/imgproxy/imgproxy/v3/options" - "github.com/imgproxy/imgproxy/v3/router" + "github.com/imgproxy/imgproxy/v3/server" ) var ( @@ -44,7 +45,7 @@ var ( } ) -func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) { +func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) error { stats.IncImagesInProgress() defer stats.DecImagesInProgress() @@ -65,18 +66,24 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht if config.CookiePassthrough { cookieJar, err = cookies.JarFromRequest(r) - checkErr(ctx, "streaming", err) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming)) + } } req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar) defer req.Cancel() - checkErr(ctx, "streaming", err) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming)) + } res, err := req.Send() if res != nil { defer res.Body.Close() } - checkErr(ctx, "streaming", err) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming)) + } for _, k := range streamRespHeaders { vv := res.Header.Values(k) @@ -116,7 +123,7 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht copyerr = nil } - router.LogResponse( + server.LogResponse( reqID, r, res.StatusCode, nil, log.Fields{ "image_url": imageURL, @@ -127,4 +134,6 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht if copyerr != nil { panic(http.ErrAbortHandler) } + + return nil }