mirror of
https://github.com/imgproxy/imgproxy.git
synced 2025-09-29 04:53:05 +02:00
IMG-51: router -> server ns, handlers ns, added error to handler ret val (#1494)
* Introduced server, handlers, error ret in handlerfn * Server struct with tests * replace checkErr with return
This commit is contained in:
26
errors.go
26
errors.go
@@ -7,11 +7,23 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
)
|
||||
|
||||
// Monitoring error categories
|
||||
const (
|
||||
categoryTimeout = "timeout"
|
||||
categoryImageDataSize = "image_data_size"
|
||||
categoryPathParsing = "path_parsing"
|
||||
categorySecurity = "security"
|
||||
categoryQueue = "queue"
|
||||
categoryDownload = "download"
|
||||
categoryProcessing = "processing"
|
||||
categoryIO = "IO"
|
||||
categoryStreaming = "streaming"
|
||||
)
|
||||
|
||||
type (
|
||||
ResponseWriteError struct{ error }
|
||||
InvalidURLError string
|
||||
TooManyRequestsError struct{}
|
||||
InvalidSecretError struct{}
|
||||
)
|
||||
|
||||
func newResponseWriteError(cause error) *ierrors.Error {
|
||||
@@ -53,15 +65,3 @@ func newTooManyRequestsError() error {
|
||||
}
|
||||
|
||||
func (e TooManyRequestsError) Error() string { return "Too many requests" }
|
||||
|
||||
func newInvalidSecretError() error {
|
||||
return ierrors.Wrap(
|
||||
InvalidSecretError{},
|
||||
1,
|
||||
ierrors.WithStatusCode(http.StatusForbidden),
|
||||
ierrors.WithPublicMessage("Forbidden"),
|
||||
ierrors.WithShouldReport(false),
|
||||
)
|
||||
}
|
||||
|
||||
func (e InvalidSecretError) Error() string { return "Invalid secret" }
|
||||
|
46
handlers/health.go
Normal file
46
handlers/health.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
)
|
||||
|
||||
var imgproxyIsRunningMsg = []byte("imgproxy is running")
|
||||
|
||||
// HealthHandler handles the health check requests
|
||||
func HealthHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
var (
|
||||
status int
|
||||
msg []byte
|
||||
ierr *ierrors.Error
|
||||
)
|
||||
|
||||
if err := vips.Health(); err == nil {
|
||||
status = http.StatusOK
|
||||
msg = imgproxyIsRunningMsg
|
||||
} else {
|
||||
status = http.StatusInternalServerError
|
||||
msg = []byte("Error")
|
||||
ierr = ierrors.Wrap(err, 1)
|
||||
}
|
||||
|
||||
if len(msg) == 0 {
|
||||
msg = []byte{' '}
|
||||
}
|
||||
|
||||
// Log response only if something went wrong
|
||||
if ierr != nil {
|
||||
server.LogResponse(reqID, r, status, ierr)
|
||||
}
|
||||
|
||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||
rw.Header().Set(httpheaders.CacheControl, "no-cache")
|
||||
rw.WriteHeader(status)
|
||||
rw.Write(msg)
|
||||
|
||||
return nil
|
||||
}
|
32
handlers/health_test.go
Normal file
32
handlers/health_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
)
|
||||
|
||||
func TestHealthHandler(t *testing.T) {
|
||||
// Create a ResponseRecorder to record the response
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Call the handler function directly (no need for actual HTTP request)
|
||||
HealthHandler("test-req-id", rr, nil)
|
||||
|
||||
// Check that we get a valid response (either 200 or 500 depending on vips state)
|
||||
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)
|
||||
|
||||
// Check headers are set correctly
|
||||
assert.Equal(t, "text/plain", rr.Header().Get(httpheaders.ContentType))
|
||||
assert.Equal(t, "no-cache", rr.Header().Get(httpheaders.CacheControl))
|
||||
|
||||
// Verify response format and content
|
||||
body := rr.Body.String()
|
||||
assert.NotEmpty(t, body)
|
||||
|
||||
assert.Equal(t, "imgproxy is running", body)
|
||||
}
|
@@ -1,6 +1,10 @@
|
||||
package main
|
||||
package handlers
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
)
|
||||
|
||||
var landingTmpl = []byte(`
|
||||
<!doctype html>
|
||||
@@ -39,8 +43,10 @@ var landingTmpl = []byte(`
|
||||
</html>
|
||||
`)
|
||||
|
||||
func handleLanding(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
rw.WriteHeader(200)
|
||||
// LandingHandler handles the landing page requests
|
||||
func LandingHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
rw.Header().Set(httpheaders.ContentType, "text/html")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write(landingTmpl)
|
||||
return nil
|
||||
}
|
@@ -17,6 +17,7 @@ const (
|
||||
AltSvc = "Alt-Svc"
|
||||
Authorization = "Authorization"
|
||||
CacheControl = "Cache-Control"
|
||||
CFConnectingIP = "CF-Connecting-IP"
|
||||
Connection = "Connection"
|
||||
ContentDisposition = "Content-Disposition"
|
||||
ContentEncoding = "Content-Encoding"
|
||||
@@ -56,6 +57,7 @@ const (
|
||||
Vary = "Vary"
|
||||
Via = "Via"
|
||||
WwwAuthenticate = "Www-Authenticate"
|
||||
XAmznRequestContextHeader = "x-amzn-request-context"
|
||||
XContentTypeOptions = "X-Content-Type-Options"
|
||||
XForwardedFor = "X-Forwarded-For"
|
||||
XForwardedHost = "X-Forwarded-Host"
|
||||
@@ -63,6 +65,8 @@ const (
|
||||
XFrameOptions = "X-Frame-Options"
|
||||
XOriginWidth = "X-Origin-Width"
|
||||
XOriginHeight = "X-Origin-Height"
|
||||
XRealIP = "X-Real-IP"
|
||||
XRequestID = "X-Request-ID"
|
||||
XResultWidth = "X-Result-Width"
|
||||
XResultHeight = "X-Result-Height"
|
||||
XOriginContentLength = "X-Origin-Content-Length"
|
||||
|
@@ -7,6 +7,10 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCategory = "default"
|
||||
)
|
||||
|
||||
type Option func(*Error)
|
||||
|
||||
type Error struct {
|
||||
@@ -16,6 +20,7 @@ type Error struct {
|
||||
statusCode int
|
||||
publicMessage string
|
||||
shouldReport bool
|
||||
category string
|
||||
|
||||
stack []uintptr
|
||||
}
|
||||
@@ -64,6 +69,14 @@ func (e *Error) Callers() []uintptr {
|
||||
return e.stack
|
||||
}
|
||||
|
||||
func (e *Error) Category() string {
|
||||
if e.category == "" {
|
||||
return defaultCategory
|
||||
}
|
||||
|
||||
return e.category
|
||||
}
|
||||
|
||||
func (e *Error) FormatStackLines() []string {
|
||||
lines := make([]string, len(e.stack))
|
||||
|
||||
@@ -141,6 +154,12 @@ func WithShouldReport(report bool) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithCategory(category string) Option {
|
||||
return func(e *Error) {
|
||||
e.category = category
|
||||
}
|
||||
}
|
||||
|
||||
func callers(skip int) []uintptr {
|
||||
stack := make([]uintptr, 10)
|
||||
n := runtime.Callers(skip+2, stack)
|
||||
|
44
main.go
44
main.go
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/config/loadenv"
|
||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||
"github.com/imgproxy/imgproxy/v3/gliblog"
|
||||
"github.com/imgproxy/imgproxy/v3/handlers"
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/logger"
|
||||
"github.com/imgproxy/imgproxy/v3/memory"
|
||||
@@ -23,10 +24,37 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/metrics/prometheus"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/processing"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
"github.com/imgproxy/imgproxy/v3/version"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
)
|
||||
|
||||
const (
|
||||
faviconPath = "/favicon.ico"
|
||||
healthPath = "/health"
|
||||
)
|
||||
|
||||
func buildRouter(r *server.Router) *server.Router {
|
||||
r.GET("/", true, handlers.LandingHandler)
|
||||
r.GET("", true, handlers.LandingHandler)
|
||||
|
||||
r.GET(
|
||||
"/", false, handleProcessing,
|
||||
r.WithSecret, r.WithCORS, r.WithPanic, r.WithReportError, r.WithMetrics,
|
||||
)
|
||||
|
||||
r.HEAD("/", false, r.OkHandler, r.WithCORS)
|
||||
r.OPTIONS("/", false, r.OkHandler, r.WithCORS)
|
||||
|
||||
r.GET(faviconPath, true, r.NotFoundHandler).Silent()
|
||||
r.GET(healthPath, true, handlers.HealthHandler).Silent()
|
||||
if config.HealthCheckPath != "" {
|
||||
r.GET(config.HealthCheckPath, true, handlers.HealthHandler).Silent()
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func initialize() error {
|
||||
if err := loadenv.Load(); err != nil {
|
||||
return err
|
||||
@@ -103,25 +131,21 @@ func run(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
if err := prometheus.StartServer(cancel); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, err := startServer(cancel)
|
||||
cfg := server.NewConfigFromEnv()
|
||||
r := server.NewRouter(cfg)
|
||||
s, err := server.Start(cancel, buildRouter(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer shutdownServer(s)
|
||||
defer s.Shutdown(ctx)
|
||||
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-stop:
|
||||
}
|
||||
<-ctx.Done()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -161,7 +161,7 @@ func StartServer(cancel context.CancelFunc) error {
|
||||
|
||||
s := http.Server{Handler: promhttp.Handler()}
|
||||
|
||||
l, err := reuseport.Listen("tcp", config.PrometheusBind)
|
||||
l, err := reuseport.Listen("tcp", config.PrometheusBind, config.SoReuseport)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Can't start Prometheus metrics server: %s", err)
|
||||
}
|
||||
|
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/router"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ func (p pipeline) Run(ctx context.Context, img *vips.Image, po *options.Processi
|
||||
return err
|
||||
}
|
||||
|
||||
if err := router.CheckTimeout(ctx); err != nil {
|
||||
if err := server.CheckTimeout(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@@ -12,8 +12,8 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/router"
|
||||
"github.com/imgproxy/imgproxy/v3/security"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
"github.com/imgproxy/imgproxy/v3/svg"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
)
|
||||
@@ -173,7 +173,7 @@ func transformAnimated(ctx context.Context, img *vips.Image, po *options.Process
|
||||
return err
|
||||
}
|
||||
|
||||
if err = router.CheckTimeout(ctx); err != nil {
|
||||
if err = server.CheckTimeout(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -240,7 +240,7 @@ func saveImageToFitBytes(ctx context.Context, po *options.ProcessingOptions, img
|
||||
}
|
||||
imgdata.Close()
|
||||
|
||||
if err := router.CheckTimeout(ctx); err != nil {
|
||||
if err := server.CheckTimeout(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -27,8 +26,8 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/metrics/stats"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/processing"
|
||||
"github.com/imgproxy/imgproxy/v3/router"
|
||||
"github.com/imgproxy/imgproxy/v3/security"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
)
|
||||
|
||||
@@ -122,17 +121,19 @@ func setCanonical(rw http.ResponseWriter, originURL string) {
|
||||
}
|
||||
}
|
||||
|
||||
func writeOriginContentLengthDebugHeader(ctx context.Context, rw http.ResponseWriter, originData imagedata.ImageData) {
|
||||
func writeOriginContentLengthDebugHeader(rw http.ResponseWriter, originData imagedata.ImageData) error {
|
||||
if !config.EnableDebugHeaders {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
size, err := originData.Size()
|
||||
if err != nil {
|
||||
checkErr(ctx, "image_data_size", err)
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
|
||||
}
|
||||
|
||||
rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
|
||||
@@ -146,13 +147,13 @@ func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
|
||||
rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight))
|
||||
}
|
||||
|
||||
func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) {
|
||||
func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) error {
|
||||
// We read the size of the image data here, so we can set Content-Length header.
|
||||
// This indireclty ensures that the image data is fully read from the source, no
|
||||
// errors happened.
|
||||
resultSize, err := resultData.Size()
|
||||
if err != nil {
|
||||
checkErr(r.Context(), "image_data_size", err)
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
|
||||
}
|
||||
|
||||
contentDisposition := httpheaders.ContentDispositionValue(
|
||||
@@ -183,18 +184,19 @@ func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, sta
|
||||
ierr = newResponseWriteError(err)
|
||||
|
||||
if config.ReportIOErrors {
|
||||
sendErr(r.Context(), "IO", ierr)
|
||||
errorreport.Report(ierr, r)
|
||||
return ierrors.Wrap(ierr, 0, ierrors.WithCategory(categoryIO), ierrors.WithShouldReport(true))
|
||||
}
|
||||
}
|
||||
|
||||
router.LogResponse(
|
||||
server.LogResponse(
|
||||
reqID, r, statusCode, ierr,
|
||||
log.Fields{
|
||||
"image_url": originURL,
|
||||
"processing_options": po,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders http.Header) {
|
||||
@@ -202,7 +204,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite
|
||||
setVary(rw)
|
||||
|
||||
rw.WriteHeader(304)
|
||||
router.LogResponse(
|
||||
server.LogResponse(
|
||||
reqID, r, 304, nil,
|
||||
log.Fields{
|
||||
"image_url": originURL,
|
||||
@@ -211,37 +213,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite
|
||||
)
|
||||
}
|
||||
|
||||
func sendErr(ctx context.Context, errType string, err error) {
|
||||
send := true
|
||||
|
||||
if ierr, ok := err.(*ierrors.Error); ok {
|
||||
switch ierr.StatusCode() {
|
||||
case http.StatusServiceUnavailable:
|
||||
errType = "timeout"
|
||||
case 499:
|
||||
// Don't need to send a "request cancelled" error
|
||||
send = false
|
||||
}
|
||||
}
|
||||
|
||||
if send {
|
||||
metrics.SendError(ctx, errType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendErrAndPanic(ctx context.Context, errType string, err error) {
|
||||
sendErr(ctx, errType, err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func checkErr(ctx context.Context, errType string, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
sendErrAndPanic(ctx, errType, err)
|
||||
}
|
||||
|
||||
func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
stats.IncRequestsInProgress()
|
||||
defer stats.DecRequestsInProgress()
|
||||
|
||||
@@ -263,19 +235,22 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
signature = path[:signatureEnd]
|
||||
path = path[signatureEnd:]
|
||||
} else {
|
||||
sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf(
|
||||
http.StatusNotFound, "Invalid path: %s", path),
|
||||
return ierrors.Wrap(
|
||||
newInvalidURLErrorf(http.StatusNotFound, "Invalid path: %s", path), 0,
|
||||
ierrors.WithCategory(categoryPathParsing),
|
||||
)
|
||||
}
|
||||
|
||||
path = fixPath(path)
|
||||
|
||||
if err := security.VerifySignature(signature, path); err != nil {
|
||||
sendErrAndPanic(ctx, "security", err)
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity))
|
||||
}
|
||||
|
||||
po, imageURL, err := options.ParsePath(path, r.Header)
|
||||
checkErr(ctx, "path_parsing", err)
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryPathParsing))
|
||||
}
|
||||
|
||||
var imageOrigin any
|
||||
if u, uerr := url.Parse(imageURL); uerr == nil {
|
||||
@@ -295,19 +270,20 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
metrics.SetMetadata(ctx, metricsMeta)
|
||||
|
||||
err = security.VerifySourceURL(imageURL)
|
||||
checkErr(ctx, "security", err)
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity))
|
||||
}
|
||||
|
||||
if po.Raw {
|
||||
streamOriginImage(ctx, reqID, r, rw, po, imageURL)
|
||||
return
|
||||
return streamOriginImage(ctx, reqID, r, rw, po, imageURL)
|
||||
}
|
||||
|
||||
// SVG is a special case. Though saving to svg is not supported, SVG->SVG is.
|
||||
if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG {
|
||||
sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf(
|
||||
return ierrors.Wrap(newInvalidURLErrorf(
|
||||
http.StatusUnprocessableEntity,
|
||||
"Resulting image format is not supported: %s", po.Format,
|
||||
))
|
||||
), 0, ierrors.WithCategory(categoryPathParsing))
|
||||
}
|
||||
|
||||
imgRequestHeader := make(http.Header)
|
||||
@@ -339,7 +315,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// The heavy part starts here, so we need to restrict worker number
|
||||
func() {
|
||||
err = func() error {
|
||||
defer metrics.StartQueueSegment(ctx)()
|
||||
|
||||
err = processingSem.Acquire(ctx, 1)
|
||||
@@ -347,12 +323,21 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
// We don't actually need to check timeout here,
|
||||
// but it's an easy way to check if this is an actual timeout
|
||||
// or the request was canceled
|
||||
checkErr(ctx, "queue", router.CheckTimeout(ctx))
|
||||
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||
}
|
||||
|
||||
// We should never reach this line as err could be only ctx.Err()
|
||||
// and we've already checked for it. But beter safe than sorry
|
||||
sendErrAndPanic(ctx, "queue", err)
|
||||
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryQueue))
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer processingSem.Release(1)
|
||||
|
||||
stats.IncImagesInProgress()
|
||||
@@ -375,7 +360,9 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if config.CookiePassthrough {
|
||||
downloadOpts.CookieJar, err = cookies.JarFromRequest(r)
|
||||
checkErr(ctx, "download", err)
|
||||
if err != nil {
|
||||
return nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload))
|
||||
}
|
||||
}
|
||||
|
||||
return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts)
|
||||
@@ -393,26 +380,28 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers())
|
||||
return
|
||||
return nil
|
||||
|
||||
default:
|
||||
// This may be a request timeout error or a request cancelled error.
|
||||
// Check it before moving further
|
||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
||||
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||
}
|
||||
|
||||
ierr := ierrors.Wrap(err, 0)
|
||||
if config.ReportDownloadingErrors {
|
||||
ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true))
|
||||
}
|
||||
|
||||
sendErr(ctx, "download", ierr)
|
||||
|
||||
if imagedata.FallbackImage == nil {
|
||||
panic(ierr)
|
||||
return ierr
|
||||
}
|
||||
|
||||
// We didn't panic, so the error is not reported.
|
||||
// Report it now
|
||||
// Just send error
|
||||
metrics.SendError(ctx, categoryDownload, ierr)
|
||||
|
||||
// We didn't return, so we have to report error
|
||||
if ierr.ShouldReport() {
|
||||
errorreport.Report(ierr, r)
|
||||
}
|
||||
@@ -433,27 +422,33 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
||||
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||
}
|
||||
|
||||
if config.ETagEnabled && statusCode == http.StatusOK {
|
||||
imgDataMatch, terr := etagHandler.SetActualImageData(originData, originHeaders)
|
||||
if terr == nil {
|
||||
rw.Header().Set("ETag", etagHandler.GenerateActualETag())
|
||||
imgDataMatch, eerr := etagHandler.SetActualImageData(originData, originHeaders)
|
||||
if eerr != nil && config.ReportIOErrors {
|
||||
return ierrors.Wrap(eerr, 0, ierrors.WithCategory(categoryIO))
|
||||
}
|
||||
|
||||
if imgDataMatch && etagHandler.ProcessingOptionsMatch() {
|
||||
respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders)
|
||||
return
|
||||
}
|
||||
rw.Header().Set("ETag", etagHandler.GenerateActualETag())
|
||||
|
||||
if imgDataMatch && etagHandler.ProcessingOptionsMatch() {
|
||||
respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
||||
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||
}
|
||||
|
||||
if !vips.SupportsLoad(originData.Format()) {
|
||||
sendErrAndPanic(ctx, "processing", newInvalidURLErrorf(
|
||||
return ierrors.Wrap(newInvalidURLErrorf(
|
||||
http.StatusUnprocessableEntity,
|
||||
"Source image format is not supported: %s", originData.Format(),
|
||||
))
|
||||
), 0, ierrors.WithCategory(categoryProcessing))
|
||||
}
|
||||
|
||||
result, err := func() (*processing.Result, error) {
|
||||
@@ -468,18 +463,31 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
defer result.OutData.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// First, check if the processing error wasn't caused by an image data error
|
||||
checkErr(ctx, "download", originData.Error())
|
||||
|
||||
// If it wasn't, than it was a processing error
|
||||
sendErrAndPanic(ctx, "processing", err)
|
||||
// First, check if the processing error wasn't caused by an image data error
|
||||
if originData.Error() != nil {
|
||||
return ierrors.Wrap(originData.Error(), 0, ierrors.WithCategory(categoryDownload))
|
||||
}
|
||||
|
||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
||||
// If it wasn't, than it was a processing error
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryProcessing))
|
||||
}
|
||||
|
||||
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||
}
|
||||
|
||||
writeDebugHeaders(rw, result)
|
||||
writeOriginContentLengthDebugHeader(ctx, rw, originData)
|
||||
|
||||
respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders)
|
||||
err = writeOriginContentLengthDebugHeader(rw, originData)
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
|
||||
}
|
||||
|
||||
err = respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/router"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
"github.com/imgproxy/imgproxy/v3/svg"
|
||||
"github.com/imgproxy/imgproxy/v3/testutil"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
type ProcessingHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
router *router.Router
|
||||
router *server.Router
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) SetupSuite() {
|
||||
@@ -48,7 +48,7 @@ func (s *ProcessingHandlerTestSuite) SetupSuite() {
|
||||
|
||||
logrus.SetOutput(io.Discard)
|
||||
|
||||
s.router = buildRouter()
|
||||
s.router = buildRouter(server.NewRouter(server.NewConfigFromEnv()))
|
||||
}
|
||||
|
||||
func (s *ProcessingHandlerTestSuite) TeardownSuite() {
|
||||
|
@@ -7,12 +7,10 @@ import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
)
|
||||
|
||||
func Listen(network, address string) (net.Listener, error) {
|
||||
if config.SoReuseport {
|
||||
func Listen(network, address string, reuse bool) (net.Listener, error) {
|
||||
if reuse {
|
||||
log.Warning("SO_REUSEPORT support is not implemented for your OS or Go version")
|
||||
}
|
||||
|
||||
|
@@ -10,12 +10,10 @@ import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
)
|
||||
|
||||
func Listen(network, address string) (net.Listener, error) {
|
||||
if !config.SoReuseport {
|
||||
func Listen(network, address string, reuse bool) (net.Listener, error) {
|
||||
if !reuse {
|
||||
return net.Listen(network, address)
|
||||
}
|
||||
|
||||
|
174
router/router.go
174
router/router.go
@@ -1,174 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
nanoid "github.com/matoous/go-nanoid/v2"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
)
|
||||
|
||||
const (
|
||||
xRequestIDHeader = "X-Request-ID"
|
||||
xAmznRequestContextHeader = "x-amzn-request-context"
|
||||
)
|
||||
|
||||
var (
|
||||
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
||||
)
|
||||
|
||||
type RouteHandler func(string, http.ResponseWriter, *http.Request)
|
||||
|
||||
type route struct {
|
||||
Method string
|
||||
Prefix string
|
||||
Handler RouteHandler
|
||||
Exact bool
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
prefix string
|
||||
healthRoutes []string
|
||||
faviconRoute string
|
||||
|
||||
Routes []*route
|
||||
HealthHandler RouteHandler
|
||||
}
|
||||
|
||||
func (r *route) isMatch(req *http.Request) bool {
|
||||
if r.Method != req.Method {
|
||||
return false
|
||||
}
|
||||
|
||||
if r.Exact {
|
||||
return req.URL.Path == r.Prefix
|
||||
}
|
||||
|
||||
return strings.HasPrefix(req.URL.Path, r.Prefix)
|
||||
}
|
||||
|
||||
func New(prefix string) *Router {
|
||||
healthRoutes := []string{prefix + "/health"}
|
||||
if len(config.HealthCheckPath) > 0 {
|
||||
healthRoutes = append(healthRoutes, prefix+config.HealthCheckPath)
|
||||
}
|
||||
|
||||
return &Router{
|
||||
prefix: prefix,
|
||||
healthRoutes: healthRoutes,
|
||||
faviconRoute: prefix + "/favicon.ico",
|
||||
Routes: make([]*route, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) Add(method, prefix string, handler RouteHandler, exact bool) {
|
||||
// Don't add routes with empty prefix
|
||||
if len(r.prefix+prefix) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.Routes = append(
|
||||
r.Routes,
|
||||
&route{Method: method, Prefix: r.prefix + prefix, Handler: handler, Exact: exact},
|
||||
)
|
||||
}
|
||||
|
||||
func (r *Router) GET(prefix string, handler RouteHandler, exact bool) {
|
||||
r.Add(http.MethodGet, prefix, handler, exact)
|
||||
}
|
||||
|
||||
func (r *Router) OPTIONS(prefix string, handler RouteHandler, exact bool) {
|
||||
r.Add(http.MethodOptions, prefix, handler, exact)
|
||||
}
|
||||
|
||||
func (r *Router) HEAD(prefix string, handler RouteHandler, exact bool) {
|
||||
r.Add(http.MethodHead, prefix, handler, exact)
|
||||
}
|
||||
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
req, timeoutCancel := startRequestTimer(req)
|
||||
defer timeoutCancel()
|
||||
|
||||
rw = newTimeoutResponse(rw)
|
||||
|
||||
reqID := req.Header.Get(xRequestIDHeader)
|
||||
|
||||
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
||||
if lambdaContextVal := req.Header.Get(xAmznRequestContextHeader); len(lambdaContextVal) > 0 {
|
||||
var lambdaContext struct {
|
||||
RequestID string `json:"requestId"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext)
|
||||
if err == nil && len(lambdaContext.RequestID) > 0 {
|
||||
reqID = lambdaContext.RequestID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
||||
reqID, _ = nanoid.New()
|
||||
}
|
||||
|
||||
rw.Header().Set("Server", "imgproxy")
|
||||
rw.Header().Set(xRequestIDHeader, reqID)
|
||||
|
||||
if req.Method == http.MethodGet {
|
||||
if r.HealthHandler != nil {
|
||||
for _, healthRoute := range r.healthRoutes {
|
||||
if req.URL.Path == healthRoute {
|
||||
r.HealthHandler(reqID, rw, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL.Path == r.faviconRoute {
|
||||
// TODO: Add a real favicon maybe?
|
||||
rw.Header().Set("Content-Type", "text/plain")
|
||||
rw.WriteHeader(404)
|
||||
// Write a single byte to make AWS Lambda happy
|
||||
rw.Write([]byte{' '})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if ip := req.Header.Get("CF-Connecting-IP"); len(ip) != 0 {
|
||||
replaceRemoteAddr(req, ip)
|
||||
} else if ip := req.Header.Get("X-Forwarded-For"); len(ip) != 0 {
|
||||
if index := strings.Index(ip, ","); index > 0 {
|
||||
ip = ip[:index]
|
||||
}
|
||||
replaceRemoteAddr(req, ip)
|
||||
} else if ip := req.Header.Get("X-Real-IP"); len(ip) != 0 {
|
||||
replaceRemoteAddr(req, ip)
|
||||
}
|
||||
|
||||
LogRequest(reqID, req)
|
||||
|
||||
for _, rr := range r.Routes {
|
||||
if rr.isMatch(req) {
|
||||
rr.Handler(reqID, rw, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
LogResponse(reqID, req, 404, newRouteNotDefinedError(req.URL.Path))
|
||||
|
||||
rw.Header().Set("Content-Type", "text/plain")
|
||||
rw.WriteHeader(404)
|
||||
rw.Write([]byte{' '})
|
||||
}
|
||||
|
||||
func replaceRemoteAddr(req *http.Request, ip string) {
|
||||
_, port, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
port = "80"
|
||||
}
|
||||
|
||||
req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port)
|
||||
}
|
@@ -1,43 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
)
|
||||
|
||||
type timeoutResponse struct {
|
||||
http.ResponseWriter
|
||||
controller *http.ResponseController
|
||||
}
|
||||
|
||||
func newTimeoutResponse(rw http.ResponseWriter) http.ResponseWriter {
|
||||
return &timeoutResponse{
|
||||
ResponseWriter: rw,
|
||||
controller: http.NewResponseController(rw),
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *timeoutResponse) WriteHeader(statusCode int) {
|
||||
rw.withWriteDeadline(func() {
|
||||
rw.ResponseWriter.WriteHeader(statusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func (rw *timeoutResponse) Write(b []byte) (int, error) {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
rw.withWriteDeadline(func() {
|
||||
n, err = rw.ResponseWriter.Write(b)
|
||||
})
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (rw *timeoutResponse) withWriteDeadline(f func()) {
|
||||
rw.controller.SetWriteDeadline(time.Now().Add(time.Duration(config.WriteResponseTimeout) * time.Second))
|
||||
defer rw.controller.SetWriteDeadline(time.Time{})
|
||||
f()
|
||||
}
|
204
server.go
204
server.go
@@ -1,204 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
golog "log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/netutil"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/metrics"
|
||||
"github.com/imgproxy/imgproxy/v3/reuseport"
|
||||
"github.com/imgproxy/imgproxy/v3/router"
|
||||
"github.com/imgproxy/imgproxy/v3/vips"
|
||||
)
|
||||
|
||||
var imgproxyIsRunningMsg = []byte("imgproxy is running")
|
||||
|
||||
func buildRouter() *router.Router {
|
||||
r := router.New(config.PathPrefix)
|
||||
|
||||
r.GET("/", handleLanding, true)
|
||||
r.GET("", handleLanding, true)
|
||||
|
||||
r.GET("/", withMetrics(withPanicHandler(withCORS(withSecret(handleProcessing)))), false)
|
||||
|
||||
r.HEAD("/", withCORS(handleHead), false)
|
||||
r.OPTIONS("/", withCORS(handleHead), false)
|
||||
|
||||
r.HealthHandler = handleHealth
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func startServer(cancel context.CancelFunc) (*http.Server, error) {
|
||||
l, err := reuseport.Listen(config.Network, config.Bind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't start server: %s", err)
|
||||
}
|
||||
|
||||
if config.MaxClients > 0 {
|
||||
l = netutil.LimitListener(l, config.MaxClients)
|
||||
}
|
||||
|
||||
errLogger := golog.New(
|
||||
log.WithField("source", "http_server").WriterLevel(log.ErrorLevel),
|
||||
"", 0,
|
||||
)
|
||||
|
||||
s := &http.Server{
|
||||
Handler: buildRouter(),
|
||||
ReadTimeout: time.Duration(config.ReadRequestTimeout) * time.Second,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
ErrorLog: errLogger,
|
||||
}
|
||||
|
||||
if config.KeepAliveTimeout > 0 {
|
||||
s.IdleTimeout = time.Duration(config.KeepAliveTimeout) * time.Second
|
||||
} else {
|
||||
s.SetKeepAlivesEnabled(false)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting server at %s", config.Bind)
|
||||
if err := s.Serve(l); err != nil && err != http.ErrServerClosed {
|
||||
log.Error(err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func shutdownServer(s *http.Server) {
|
||||
log.Info("Shutting down the server...")
|
||||
|
||||
ctx, close := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer close()
|
||||
|
||||
s.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func withMetrics(h router.RouteHandler) router.RouteHandler {
|
||||
if !metrics.Enabled() {
|
||||
return h
|
||||
}
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
ctx, metricsCancel, rw := metrics.StartRequest(r.Context(), rw, r)
|
||||
defer metricsCancel()
|
||||
|
||||
h(reqID, rw, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func withCORS(h router.RouteHandler) router.RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
if len(config.AllowOrigin) > 0 {
|
||||
rw.Header().Set("Access-Control-Allow-Origin", config.AllowOrigin)
|
||||
rw.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
}
|
||||
|
||||
h(reqID, rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func withSecret(h router.RouteHandler) router.RouteHandler {
|
||||
if len(config.Secret) == 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
authHeader := []byte(fmt.Sprintf("Bearer %s", config.Secret))
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), authHeader) == 1 {
|
||||
h(reqID, rw, r)
|
||||
} else {
|
||||
panic(newInvalidSecretError())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func withPanicHandler(h router.RouteHandler) router.RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := errorreport.StartRequest(r)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
errorreport.SetMetadata(r, "Request ID", reqID)
|
||||
|
||||
defer func() {
|
||||
if rerr := recover(); rerr != nil {
|
||||
if rerr == http.ErrAbortHandler {
|
||||
panic(rerr)
|
||||
}
|
||||
|
||||
err, ok := rerr.(error)
|
||||
if !ok {
|
||||
panic(rerr)
|
||||
}
|
||||
|
||||
ierr := ierrors.Wrap(err, 0)
|
||||
|
||||
if ierr.ShouldReport() {
|
||||
errorreport.Report(err, r)
|
||||
}
|
||||
|
||||
router.LogResponse(reqID, r, ierr.StatusCode(), ierr)
|
||||
|
||||
rw.Header().Set("Content-Type", "text/plain")
|
||||
rw.WriteHeader(ierr.StatusCode())
|
||||
|
||||
if config.DevelopmentErrorsMode {
|
||||
rw.Write([]byte(ierr.Error()))
|
||||
} else {
|
||||
rw.Write([]byte(ierr.PublicMessage()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
h(reqID, rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func handleHealth(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
status int
|
||||
msg []byte
|
||||
ierr *ierrors.Error
|
||||
)
|
||||
|
||||
if err := vips.Health(); err == nil {
|
||||
status = http.StatusOK
|
||||
msg = imgproxyIsRunningMsg
|
||||
} else {
|
||||
status = http.StatusInternalServerError
|
||||
msg = []byte("Error")
|
||||
ierr = ierrors.Wrap(err, 1)
|
||||
}
|
||||
|
||||
if len(msg) == 0 {
|
||||
msg = []byte{' '}
|
||||
}
|
||||
|
||||
// Log response only if something went wrong
|
||||
if ierr != nil {
|
||||
router.LogResponse(reqID, r, status, ierr)
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "text/plain")
|
||||
rw.Header().Set("Cache-Control", "no-cache")
|
||||
rw.WriteHeader(status)
|
||||
rw.Write(msg)
|
||||
}
|
||||
|
||||
func handleHead(reqID string, rw http.ResponseWriter, r *http.Request) {
|
||||
router.LogResponse(reqID, r, 200, nil)
|
||||
rw.WriteHeader(200)
|
||||
}
|
49
server/config.go
Normal file
49
server/config.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// gracefulTimeout represents graceful shutdown timeout
|
||||
gracefulTimeout = time.Duration(5 * time.Second)
|
||||
)
|
||||
|
||||
// Config represents HTTP server config
|
||||
type Config struct {
|
||||
Listen string // Address to listen on
|
||||
Network string // Network type (tcp, unix)
|
||||
Bind string // Bind address
|
||||
PathPrefix string // Path prefix for the server
|
||||
MaxClients int // Maximum number of concurrent clients
|
||||
ReadRequestTimeout time.Duration // Timeout for reading requests
|
||||
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
|
||||
GracefulTimeout time.Duration // Timeout for graceful shutdown
|
||||
CORSAllowOrigin string // CORS allowed origin
|
||||
Secret string // Secret for authorization
|
||||
DevelopmentErrorsMode bool // Enable development mode for detailed error messages
|
||||
SocketReusePort bool // Enable SO_REUSEPORT socket option
|
||||
HealthCheckPath string // Health check path from config
|
||||
WriteResponseTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewConfigFromEnv creates a new Config instance from environment variables
|
||||
func NewConfigFromEnv() *Config {
|
||||
return &Config{
|
||||
Network: config.Network,
|
||||
Bind: config.Bind,
|
||||
PathPrefix: config.PathPrefix,
|
||||
MaxClients: config.MaxClients,
|
||||
ReadRequestTimeout: time.Duration(config.ReadRequestTimeout) * time.Second,
|
||||
KeepAliveTimeout: time.Duration(config.KeepAliveTimeout) * time.Second,
|
||||
GracefulTimeout: gracefulTimeout,
|
||||
CORSAllowOrigin: config.AllowOrigin,
|
||||
Secret: config.Secret,
|
||||
DevelopmentErrorsMode: config.DevelopmentErrorsMode,
|
||||
SocketReusePort: config.SoReuseport,
|
||||
HealthCheckPath: config.HealthCheckPath,
|
||||
WriteResponseTimeout: time.Duration(config.WriteResponseTimeout) * time.Second,
|
||||
}
|
||||
}
|
@@ -1,6 +1,7 @@
|
||||
package router
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ type (
|
||||
RouteNotDefinedError string
|
||||
RequestCancelledError string
|
||||
RequestTimeoutError string
|
||||
InvalidSecretError struct{}
|
||||
)
|
||||
|
||||
func newRouteNotDefinedError(path string) *ierrors.Error {
|
||||
@@ -33,11 +35,16 @@ func newRequestCancelledError(after time.Duration) *ierrors.Error {
|
||||
ierrors.WithStatusCode(499),
|
||||
ierrors.WithPublicMessage("Cancelled"),
|
||||
ierrors.WithShouldReport(false),
|
||||
ierrors.WithCategory(categoryTimeout),
|
||||
)
|
||||
}
|
||||
|
||||
func (e RequestCancelledError) Error() string { return string(e) }
|
||||
|
||||
func (e RequestCancelledError) Unwrap() error {
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
func newRequestTimeoutError(after time.Duration) *ierrors.Error {
|
||||
return ierrors.Wrap(
|
||||
RequestTimeoutError(fmt.Sprintf("Request was timed out after %v", after)),
|
||||
@@ -45,7 +52,24 @@ func newRequestTimeoutError(after time.Duration) *ierrors.Error {
|
||||
ierrors.WithStatusCode(http.StatusServiceUnavailable),
|
||||
ierrors.WithPublicMessage("Gateway Timeout"),
|
||||
ierrors.WithShouldReport(false),
|
||||
ierrors.WithCategory(categoryTimeout),
|
||||
)
|
||||
}
|
||||
|
||||
func (e RequestTimeoutError) Error() string { return string(e) }
|
||||
|
||||
func (e RequestTimeoutError) Unwrap() error {
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
|
||||
func newInvalidSecretError() error {
|
||||
return ierrors.Wrap(
|
||||
InvalidSecretError{},
|
||||
1,
|
||||
ierrors.WithStatusCode(http.StatusForbidden),
|
||||
ierrors.WithPublicMessage("Forbidden"),
|
||||
ierrors.WithShouldReport(false),
|
||||
)
|
||||
}
|
||||
|
||||
func (e InvalidSecretError) Error() string { return "Invalid secret" }
|
@@ -1,4 +1,4 @@
|
||||
package router
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -59,6 +59,6 @@ func LogResponse(reqID string, r *http.Request, status int, err *ierrors.Error,
|
||||
|
||||
log.WithFields(fields).Logf(
|
||||
level,
|
||||
"Completed in %s %s", ctxTime(r.Context()), r.RequestURI,
|
||||
"Completed in %s %s", requestStartedAt(r.Context()), r.RequestURI,
|
||||
)
|
||||
}
|
145
server/middlewares.go
Normal file
145
server/middlewares.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
categoryTimeout = "timeout"
|
||||
)
|
||||
|
||||
// WithMetrics wraps RouteHandler with metrics handling.
|
||||
func (r *Router) WithMetrics(h RouteHandler) RouteHandler {
|
||||
if !metrics.Enabled() {
|
||||
return h
|
||||
}
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
ctx, metricsCancel, rw := metrics.StartRequest(req.Context(), rw, req)
|
||||
defer metricsCancel()
|
||||
|
||||
return h(reqID, rw, req.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// WithCORS wraps RouteHandler with CORS handling
|
||||
func (r *Router) WithCORS(h RouteHandler) RouteHandler {
|
||||
if len(r.config.CORSAllowOrigin) == 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
|
||||
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
|
||||
|
||||
return h(reqID, rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSecret wraps RouteHandler with secret handling
|
||||
func (r *Router) WithSecret(h RouteHandler) RouteHandler {
|
||||
if len(r.config.Secret) == 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
|
||||
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
|
||||
return h(reqID, rw, req)
|
||||
} else {
|
||||
return newInvalidSecretError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithPanic recovers panic and converts it to normal error
|
||||
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
|
||||
defer func() {
|
||||
// try to recover from panic
|
||||
rerr := recover()
|
||||
if rerr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// abort handler is an exception of net/http, we should simply repanic it.
|
||||
// it will supress the stack trace
|
||||
if rerr == http.ErrAbortHandler {
|
||||
panic(rerr)
|
||||
}
|
||||
|
||||
// let's recover error value from panic if it has panicked with error
|
||||
err, ok := rerr.(error)
|
||||
if !ok {
|
||||
err = fmt.Errorf("panic: %v", err)
|
||||
}
|
||||
|
||||
retErr = err
|
||||
}()
|
||||
|
||||
return h(reqID, rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
// WithReportError handles error reporting.
|
||||
// It should be placed after `WithMetrics`, but before `WithPanic`.
|
||||
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
// Open the error context
|
||||
ctx := errorreport.StartRequest(req)
|
||||
req = req.WithContext(ctx)
|
||||
errorreport.SetMetadata(req, "Request ID", reqID)
|
||||
|
||||
// Call the underlying handler passing the context downwards
|
||||
err := h(reqID, rw, req)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wrap a resulting error into ierrors.Error
|
||||
ierr := ierrors.Wrap(err, 0)
|
||||
|
||||
// Get the error category
|
||||
errCat := ierr.Category()
|
||||
|
||||
// Exception: any context.DeadlineExceeded error is timeout
|
||||
if errors.Is(ierr, context.DeadlineExceeded) {
|
||||
errCat = categoryTimeout
|
||||
}
|
||||
|
||||
// We do not need to send any canceled context
|
||||
if !errors.Is(ierr, context.Canceled) {
|
||||
metrics.SendError(ctx, errCat, err)
|
||||
}
|
||||
|
||||
// Report error to error collectors
|
||||
if ierr.ShouldReport() {
|
||||
errorreport.Report(ierr, req)
|
||||
}
|
||||
|
||||
// Log response and format the error output
|
||||
LogResponse(reqID, req, ierr.StatusCode(), ierr)
|
||||
|
||||
// Error message: either is public message or full development error
|
||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||
rw.WriteHeader(ierr.StatusCode())
|
||||
|
||||
if r.config.DevelopmentErrorsMode {
|
||||
rw.Write([]byte(ierr.Error()))
|
||||
} else {
|
||||
rw.Write([]byte(ierr.PublicMessage()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
209
server/router.go
Normal file
209
server/router.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
nanoid "github.com/matoous/go-nanoid/v2"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultServerName is the default name of the server
|
||||
defaultServerName = "imgproxy"
|
||||
)
|
||||
|
||||
var (
|
||||
// requestIDRe is a regular expression for validating request IDs
|
||||
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
||||
)
|
||||
|
||||
// RouteHandler is a function that handles HTTP requests.
|
||||
type RouteHandler func(string, http.ResponseWriter, *http.Request) error
|
||||
|
||||
// Middleware is a function that wraps a RouteHandler with additional functionality.
|
||||
type Middleware func(next RouteHandler) RouteHandler
|
||||
|
||||
// route represents a single route in the router.
|
||||
type route struct {
|
||||
method string // method is the HTTP method for a route
|
||||
path string // path represents a route path
|
||||
exact bool // exact means that path must match exactly, otherwise any prefixed matches
|
||||
handler RouteHandler // handler is the function that handles the route
|
||||
silent bool // Silent route (no logs)
|
||||
}
|
||||
|
||||
// Router is responsible for routing HTTP requests
|
||||
type Router struct {
|
||||
// config represents the server configuration
|
||||
config *Config
|
||||
|
||||
// routes is the collection of all routes
|
||||
routes []*route
|
||||
}
|
||||
|
||||
// NewRouter creates a new Router instance
|
||||
func NewRouter(config *Config) *Router {
|
||||
return &Router{config: config}
|
||||
}
|
||||
|
||||
// add adds an abitary route to the router
|
||||
func (r *Router) add(method, prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||
for _, m := range middlewares {
|
||||
handler = m(handler)
|
||||
}
|
||||
|
||||
route := &route{method: method, path: r.config.PathPrefix + prefix, handler: handler, exact: exact}
|
||||
|
||||
r.routes = append(
|
||||
r.routes,
|
||||
route,
|
||||
)
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// GET adds GET route
|
||||
func (r *Router) GET(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||
return r.add(http.MethodGet, prefix, exact, handler, middlewares...)
|
||||
}
|
||||
|
||||
// OPTIONS adds OPTIONS route
|
||||
func (r *Router) OPTIONS(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||
return r.add(http.MethodOptions, prefix, exact, handler, middlewares...)
|
||||
}
|
||||
|
||||
// HEAD adds HEAD route
|
||||
func (r *Router) HEAD(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||
return r.add(http.MethodHead, prefix, exact, handler, middlewares...)
|
||||
}
|
||||
|
||||
// ServeHTTP serves routes
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Attach timer to the context
|
||||
req, timeoutCancel := startRequestTimer(req)
|
||||
defer timeoutCancel()
|
||||
|
||||
// Create the response writer which times out on write
|
||||
rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
|
||||
|
||||
// Get/create request ID
|
||||
reqID := r.getRequestID(req)
|
||||
|
||||
// Replace request IP from headers
|
||||
r.replaceRemoteAddr(req)
|
||||
|
||||
rw.Header().Set(httpheaders.Server, defaultServerName)
|
||||
rw.Header().Set(httpheaders.XRequestID, reqID)
|
||||
|
||||
for _, rr := range r.routes {
|
||||
if rr.isMatch(req) {
|
||||
if !rr.silent {
|
||||
LogRequest(reqID, req)
|
||||
}
|
||||
|
||||
rr.handler(reqID, rw, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Means that we have not found matching route
|
||||
LogRequest(reqID, req)
|
||||
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
|
||||
r.NotFoundHandler(reqID, rw, req)
|
||||
}
|
||||
|
||||
// NotFoundHandler is default 404 handler
|
||||
func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OkHandler is a default 200 OK handler
|
||||
func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRequestID tries to read request id from headers or from lambda
|
||||
// context or generates a new one if nothing found.
|
||||
func (r *Router) getRequestID(req *http.Request) string {
|
||||
// Get request ID from headers (if any)
|
||||
reqID := req.Header.Get(httpheaders.XRequestID)
|
||||
|
||||
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
||||
lambdaContextVal := req.Header.Get(httpheaders.XAmznRequestContextHeader)
|
||||
|
||||
if len(lambdaContextVal) > 0 {
|
||||
var lambdaContext struct {
|
||||
RequestID string `json:"requestId"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext)
|
||||
if err == nil && len(lambdaContext.RequestID) > 0 {
|
||||
reqID = lambdaContext.RequestID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
||||
reqID, _ = nanoid.New()
|
||||
}
|
||||
|
||||
return reqID
|
||||
}
|
||||
|
||||
// replaceRemoteAddr rewrites the req.RemoteAddr property from request headers
|
||||
func (r *Router) replaceRemoteAddr(req *http.Request) {
|
||||
cfConnectingIP := req.Header.Get(httpheaders.CFConnectingIP)
|
||||
xForwardedFor := req.Header.Get(httpheaders.XForwardedFor)
|
||||
xRealIP := req.Header.Get(httpheaders.XRealIP)
|
||||
|
||||
switch {
|
||||
case len(cfConnectingIP) > 0:
|
||||
replaceRemoteAddr(req, cfConnectingIP)
|
||||
case len(xForwardedFor) > 0:
|
||||
if index := strings.Index(xForwardedFor, ","); index > 0 {
|
||||
xForwardedFor = xForwardedFor[:index]
|
||||
}
|
||||
replaceRemoteAddr(req, xForwardedFor)
|
||||
case len(xRealIP) > 0:
|
||||
replaceRemoteAddr(req, xRealIP)
|
||||
}
|
||||
}
|
||||
|
||||
// replaceRemoteAddr sets the req.RemoteAddr for request
|
||||
func replaceRemoteAddr(req *http.Request, ip string) {
|
||||
_, port, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
port = "80"
|
||||
}
|
||||
|
||||
req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port)
|
||||
}
|
||||
|
||||
// isMatch checks that a request matches route
|
||||
func (r *route) isMatch(req *http.Request) bool {
|
||||
methodMatches := r.method == req.Method
|
||||
notExactPathMathes := !r.exact && strings.HasPrefix(req.URL.Path, r.path)
|
||||
exactPathMatches := r.exact && req.URL.Path == r.path
|
||||
|
||||
return methodMatches && (notExactPathMathes || exactPathMatches)
|
||||
}
|
||||
|
||||
// Silent sets Silent flag which supresses logs to true. We do not need to log
|
||||
// requests like /health of /favicon.ico
|
||||
func (r *route) Silent() *route {
|
||||
r.silent = true
|
||||
return r
|
||||
}
|
296
server/router_test.go
Normal file
296
server/router_test.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
)
|
||||
|
||||
type RouterTestSuite struct {
|
||||
suite.Suite
|
||||
router *Router
|
||||
}
|
||||
|
||||
func (s *RouterTestSuite) SetupTest() {
|
||||
c := NewConfigFromEnv()
|
||||
c.PathPrefix = "/api"
|
||||
s.router = NewRouter(c)
|
||||
}
|
||||
|
||||
func TestRouterSuite(t *testing.T) {
|
||||
suite.Run(t, new(RouterTestSuite))
|
||||
}
|
||||
|
||||
// TestHTTPMethods tests route methods registration and HTTP requests
|
||||
func (s *RouterTestSuite) TestHTTPMethods() {
|
||||
var capturedMethod string
|
||||
var capturedPath string
|
||||
|
||||
getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
capturedMethod = req.Method
|
||||
capturedPath = req.URL.Path
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("GET response"))
|
||||
return nil
|
||||
}
|
||||
|
||||
optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
capturedMethod = req.Method
|
||||
capturedPath = req.URL.Path
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("OPTIONS response"))
|
||||
return nil
|
||||
}
|
||||
|
||||
headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
capturedMethod = req.Method
|
||||
capturedPath = req.URL.Path
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register routes with different configurations
|
||||
s.router.GET("/get-test", true, getHandler) // exact match
|
||||
s.router.OPTIONS("/options-test", false, optionsHandler) // prefix match
|
||||
s.router.HEAD("/head-test", true, headHandler) // exact match
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestMethod string
|
||||
requestPath string
|
||||
expectedBody string
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "GET",
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/api/get-test",
|
||||
expectedBody: "GET response",
|
||||
expectedPath: "/api/get-test",
|
||||
},
|
||||
{
|
||||
name: "OPTIONS",
|
||||
requestMethod: http.MethodOptions,
|
||||
requestPath: "/api/options-test",
|
||||
expectedBody: "OPTIONS response",
|
||||
expectedPath: "/api/options-test",
|
||||
},
|
||||
{
|
||||
name: "OPTIONSPrefixed",
|
||||
requestMethod: http.MethodOptions,
|
||||
requestPath: "/api/options-test/sub",
|
||||
expectedBody: "OPTIONS response",
|
||||
expectedPath: "/api/options-test/sub",
|
||||
},
|
||||
{
|
||||
name: "HEAD",
|
||||
requestMethod: http.MethodHead,
|
||||
requestPath: "/api/head-test",
|
||||
expectedBody: "",
|
||||
expectedPath: "/api/head-test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest(tt.requestMethod, tt.requestPath, nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
s.Require().Equal(tt.expectedBody, rw.Body.String())
|
||||
s.Require().Equal(tt.requestMethod, capturedMethod)
|
||||
s.Require().Equal(tt.expectedPath, capturedPath)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddlewareOrder checks middleware ordering and functionality
|
||||
func (s *RouterTestSuite) TestMiddlewareOrder() {
|
||||
var order []string
|
||||
|
||||
middleware1 := func(next RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
order = append(order, "middleware1")
|
||||
return next(reqID, rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
middleware2 := func(next RouteHandler) RouteHandler {
|
||||
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
order = append(order, "middleware2")
|
||||
return next(reqID, rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
order = append(order, "handler")
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.router.GET("/test", true, handler, middleware2, middleware1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
// Middleware should execute in the order they are passed (first added first)
|
||||
s.Require().Equal([]string{"middleware1", "middleware2", "handler"}, order)
|
||||
}
|
||||
|
||||
// TestServeHTTP tests ServeHTTP method
|
||||
func (s *RouterTestSuite) TestServeHTTP() {
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
rw.Header().Set("Custom-Header", "test-value")
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("success"))
|
||||
return nil
|
||||
}
|
||||
|
||||
s.router.GET("/test", true, handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
s.Require().Equal(200, rw.Code)
|
||||
s.Require().Equal("success", rw.Body.String())
|
||||
s.Require().Equal("test-value", rw.Header().Get("Custom-Header"))
|
||||
s.Require().Equal(defaultServerName, rw.Header().Get(httpheaders.Server))
|
||||
s.Require().NotEmpty(rw.Header().Get(httpheaders.XRequestID))
|
||||
}
|
||||
|
||||
// TestRequestID checks request ID generation and validation
|
||||
func (s *RouterTestSuite) TestRequestID() {
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.router.GET("/test", true, handler)
|
||||
|
||||
// Test request ID passthrough (if present)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.Header.Set(httpheaders.XRequestID, "valid-id-123")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
s.Require().Equal("valid-id-123", rw.Header().Get(httpheaders.XRequestID))
|
||||
|
||||
// Test invalid request ID (should generate a new one)
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.Header.Set(httpheaders.XRequestID, "invalid id with spaces!")
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
generatedID := rw.Header().Get(httpheaders.XRequestID)
|
||||
s.Require().NotEqual("invalid id with spaces!", generatedID)
|
||||
s.Require().NotEmpty(generatedID)
|
||||
|
||||
// Test no request ID (should generate a new one)
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
generatedID = rw.Header().Get(httpheaders.XRequestID)
|
||||
s.Require().NotEmpty(generatedID)
|
||||
s.Require().Regexp(`^[A-Za-z0-9_\-]+$`, generatedID)
|
||||
}
|
||||
|
||||
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
|
||||
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.router.GET("/test", true, handler)
|
||||
|
||||
// Test with valid Lambda context
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.Header.Set(httpheaders.XAmznRequestContextHeader, `{"requestId":"lambda-req-123"}`)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
s.Require().Equal("lambda-req-123", rw.Header().Get(httpheaders.XRequestID))
|
||||
}
|
||||
|
||||
// Test IP address handling
|
||||
func (s *RouterTestSuite) TestReplaceIP() {
|
||||
var capturedRemoteAddr string
|
||||
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||
capturedRemoteAddr = req.RemoteAddr
|
||||
rw.WriteHeader(200)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.router.GET("/test", true, handler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
originalAddr string
|
||||
headers map[string]string
|
||||
expectedAddr string
|
||||
}{
|
||||
{
|
||||
name: "CFConnectingIP",
|
||||
originalAddr: "original:8080",
|
||||
headers: map[string]string{
|
||||
httpheaders.CFConnectingIP: "1.2.3.4",
|
||||
},
|
||||
expectedAddr: "1.2.3.4:8080",
|
||||
},
|
||||
{
|
||||
name: "XForwardedForMulti",
|
||||
originalAddr: "original:8080",
|
||||
headers: map[string]string{
|
||||
httpheaders.XForwardedFor: "5.6.7.8, 9.10.11.12",
|
||||
},
|
||||
expectedAddr: "5.6.7.8:8080",
|
||||
},
|
||||
{
|
||||
name: "XForwardedForSingle",
|
||||
originalAddr: "original:8080",
|
||||
headers: map[string]string{
|
||||
httpheaders.XForwardedFor: "13.14.15.16",
|
||||
},
|
||||
expectedAddr: "13.14.15.16:8080",
|
||||
},
|
||||
{
|
||||
name: "XRealIP",
|
||||
originalAddr: "original:8080",
|
||||
headers: map[string]string{
|
||||
httpheaders.XRealIP: "17.18.19.20",
|
||||
},
|
||||
expectedAddr: "17.18.19.20:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.RemoteAddr = tt.originalAddr
|
||||
|
||||
for header, value := range tt.headers {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(rw, req)
|
||||
|
||||
s.Require().Equal(tt.expectedAddr, capturedRemoteAddr)
|
||||
})
|
||||
}
|
||||
}
|
82
server/server.go
Normal file
82
server/server.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
golog "log"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/netutil"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/reuseport"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxHeaderBytes represents max bytes in request header
|
||||
maxHeaderBytes = 1 << 20
|
||||
)
|
||||
|
||||
// Server represents the HTTP server wrapper struct
|
||||
type Server struct {
|
||||
router *Router
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// Start starts the http server. cancel is called in case server failed to start, but it happened
|
||||
// asynchronously. It should cancel the upstream context.
|
||||
func Start(cancel context.CancelFunc, router *Router) (*Server, error) {
|
||||
l, err := reuseport.Listen(router.config.Network, router.config.Bind, router.config.SocketReusePort)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("can't start server: %s", err)
|
||||
}
|
||||
|
||||
if router.config.MaxClients > 0 {
|
||||
l = netutil.LimitListener(l, router.config.MaxClients)
|
||||
}
|
||||
|
||||
errLogger := golog.New(
|
||||
log.WithField("source", "http_server").WriterLevel(log.ErrorLevel),
|
||||
"", 0,
|
||||
)
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: router,
|
||||
ReadTimeout: router.config.ReadRequestTimeout,
|
||||
MaxHeaderBytes: maxHeaderBytes,
|
||||
ErrorLog: errLogger,
|
||||
}
|
||||
|
||||
if config.KeepAliveTimeout > 0 {
|
||||
srv.IdleTimeout = router.config.KeepAliveTimeout
|
||||
} else {
|
||||
srv.SetKeepAlivesEnabled(false)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting server at %s", router.config.Bind)
|
||||
|
||||
if err := srv.Serve(l); err != nil && err != http.ErrServerClosed {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
}()
|
||||
|
||||
return &Server{
|
||||
router: router,
|
||||
server: srv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server
|
||||
func (s *Server) Shutdown(ctx context.Context) {
|
||||
log.Info("Shutting down the server...")
|
||||
|
||||
ctx, close := context.WithTimeout(ctx, s.router.config.GracefulTimeout)
|
||||
defer close()
|
||||
|
||||
s.server.Shutdown(ctx)
|
||||
}
|
268
server/server_test.go
Normal file
268
server/server_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ServerTestSuite struct {
|
||||
suite.Suite
|
||||
config *Config
|
||||
blankRouter *Router
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) SetupTest() {
|
||||
config.Reset()
|
||||
s.config = NewConfigFromEnv()
|
||||
s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
|
||||
s.blankRouter = NewRouter(s.config)
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
|
||||
ctx, cancel := context.WithCancel(s.T().Context())
|
||||
|
||||
// Track if cancel was called using atomic
|
||||
var cancelCalled atomic.Bool
|
||||
cancelWrapper := func() {
|
||||
cancel()
|
||||
cancelCalled.Store(true)
|
||||
}
|
||||
|
||||
invalidConfig := &Config{
|
||||
Network: "tcp",
|
||||
Bind: "invalid-address", // Invalid address
|
||||
}
|
||||
|
||||
r := NewRouter(invalidConfig)
|
||||
|
||||
server, err := Start(cancelWrapper, r)
|
||||
|
||||
s.Require().Error(err)
|
||||
s.Nil(server)
|
||||
s.Contains(err.Error(), "can't start server")
|
||||
|
||||
// Check if cancel was called using Eventually
|
||||
s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
|
||||
|
||||
// Also verify the context was cancelled
|
||||
s.Require().Eventually(func() bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestShutdown() {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
server, err := Start(cancel, s.blankRouter)
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(server)
|
||||
|
||||
// Test graceful shutdown
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Should not panic or hang
|
||||
s.NotPanics(func() {
|
||||
server.Shutdown(shutdownCtx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestWithCORS() {
|
||||
tests := []struct {
|
||||
name string
|
||||
corsAllowOrigin string
|
||||
expectedOrigin string
|
||||
expectedMethods string
|
||||
}{
|
||||
{
|
||||
name: "WithCORSOrigin",
|
||||
corsAllowOrigin: "https://example.com",
|
||||
expectedOrigin: "https://example.com",
|
||||
expectedMethods: "GET, OPTIONS",
|
||||
},
|
||||
{
|
||||
name: "NoCORSOrigin",
|
||||
corsAllowOrigin: "",
|
||||
expectedOrigin: "",
|
||||
expectedMethods: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
config := &Config{
|
||||
CORSAllowOrigin: tt.corsAllowOrigin,
|
||||
}
|
||||
router := NewRouter(config)
|
||||
|
||||
wrappedHandler := router.WithCORS(s.mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
|
||||
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
|
||||
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestWithSecret() {
|
||||
tests := []struct {
|
||||
name string
|
||||
secret string
|
||||
authHeader string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "ValidSecret",
|
||||
secret: "test-secret",
|
||||
authHeader: "Bearer test-secret",
|
||||
},
|
||||
{
|
||||
name: "InvalidSecret",
|
||||
secret: "foo-secret",
|
||||
authHeader: "Bearer wrong-secret",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "NoSecretConfigured",
|
||||
secret: "",
|
||||
authHeader: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
config := &Config{
|
||||
Secret: tt.secret,
|
||||
}
|
||||
router := NewRouter(config)
|
||||
|
||||
wrappedHandler := router.WithSecret(s.mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tt.authHeader != "" {
|
||||
req.Header.Set(httpheaders.Authorization, tt.authHeader)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
err := wrappedHandler("test-req-id", rw, req)
|
||||
|
||||
if tt.expectError {
|
||||
s.Require().Error(err)
|
||||
} else {
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoSuccess() {
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
}
|
||||
|
||||
wrappedHandler := s.blankRouter.WithReportError(mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
|
||||
s.Equal(http.StatusOK, rw.Code)
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoWithError() {
|
||||
testError := errors.New("test error")
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
return testError
|
||||
}
|
||||
|
||||
wrappedHandler := s.blankRouter.WithReportError(mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
|
||||
s.Equal(http.StatusInternalServerError, rw.Code)
|
||||
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoPanicWithError() {
|
||||
testError := errors.New("panic error")
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
panic(testError)
|
||||
}
|
||||
|
||||
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.NotPanics(func() {
|
||||
err := wrappedHandler("test-req-id", rw, req)
|
||||
s.Require().Error(err, "panic error")
|
||||
})
|
||||
|
||||
s.Equal(http.StatusOK, rw.Code)
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Should re-panic with ErrAbortHandler
|
||||
s.Panics(func() {
|
||||
wrappedHandler("test-req-id", rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
||||
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||
panic("string panic")
|
||||
}
|
||||
|
||||
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Should re-panic with non-error panics
|
||||
s.NotPanics(func() {
|
||||
err := wrappedHandler("test-req-id", rw, req)
|
||||
s.Require().Error(err, "string panic")
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(ServerTestSuite))
|
||||
}
|
52
server/timeout_response.go
Normal file
52
server/timeout_response.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// timeoutResponse manages response writer with timeout. It has
|
||||
// timeout on all write methods.
|
||||
type timeoutResponse struct {
|
||||
http.ResponseWriter
|
||||
controller *http.ResponseController
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// newTimeoutResponse creates a new timeoutResponse
|
||||
func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
|
||||
return &timeoutResponse{
|
||||
ResponseWriter: rw,
|
||||
controller: http.NewResponseController(rw),
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements http.ResponseWriter.Write
|
||||
func (rw *timeoutResponse) Write(b []byte) (int, error) {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
rw.withWriteDeadline(func() {
|
||||
n, err = rw.ResponseWriter.Write(b)
|
||||
})
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Header returns current HTTP headers
|
||||
func (rw *timeoutResponse) Header() http.Header {
|
||||
return rw.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// withWriteDeadline executes a Write* function with a deadline
|
||||
func (rw *timeoutResponse) withWriteDeadline(f func()) {
|
||||
deadline := time.Now().Add(rw.timeout)
|
||||
|
||||
// Set write deadline
|
||||
rw.controller.SetWriteDeadline(deadline)
|
||||
|
||||
// Reset write deadline after method has finished
|
||||
defer rw.controller.SetWriteDeadline(time.Time{})
|
||||
f()
|
||||
}
|
@@ -1,4 +1,6 @@
|
||||
package router
|
||||
// timer.go contains methods for storing, retrieving and checking
|
||||
// timer in a request context.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,8 +11,10 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
)
|
||||
|
||||
// timerSinceCtxKey represents a context key for start time.
|
||||
type timerSinceCtxKey struct{}
|
||||
|
||||
// startRequestTimer starts a new request timer.
|
||||
func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, timerSinceCtxKey{}, time.Now())
|
||||
@@ -18,17 +22,20 @@ func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
|
||||
return r.WithContext(ctx), cancel
|
||||
}
|
||||
|
||||
func ctxTime(ctx context.Context) time.Duration {
|
||||
// requestStartedAt returns the duration since the timer started in the context.
|
||||
func requestStartedAt(ctx context.Context) time.Duration {
|
||||
if t, ok := ctx.Value(timerSinceCtxKey{}).(time.Time); ok {
|
||||
return time.Since(t)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// CheckTimeout checks if the request context has timed out or cancelled and returns
|
||||
// wrapped error.
|
||||
func CheckTimeout(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
d := ctxTime(ctx)
|
||||
d := requestStartedAt(ctx)
|
||||
|
||||
err := ctx.Err()
|
||||
switch err {
|
||||
@@ -37,7 +44,7 @@ func CheckTimeout(ctx context.Context) error {
|
||||
case context.DeadlineExceeded:
|
||||
return newRequestTimeoutError(d)
|
||||
default:
|
||||
return ierrors.Wrap(err, 0)
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryTimeout))
|
||||
}
|
||||
default:
|
||||
return nil
|
67
server/timer_test.go
Normal file
67
server/timer_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCheckTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() context.Context
|
||||
fail bool
|
||||
}{
|
||||
{
|
||||
name: "WithoutTimeout",
|
||||
setup: context.Background,
|
||||
fail: false,
|
||||
},
|
||||
{
|
||||
name: "ActiveTimerContext",
|
||||
setup: func() context.Context {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
newReq, _ := startRequestTimer(req)
|
||||
return newReq.Context()
|
||||
},
|
||||
fail: false,
|
||||
},
|
||||
{
|
||||
name: "CancelledContext",
|
||||
setup: func() context.Context {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
newReq, cancel := startRequestTimer(req)
|
||||
cancel() // Cancel immediately
|
||||
return newReq.Context()
|
||||
},
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
name: "DeadlineExceeded",
|
||||
setup: func() context.Context {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
|
||||
defer cancel()
|
||||
time.Sleep(time.Millisecond * 10) // Ensure timeout
|
||||
return ctx
|
||||
},
|
||||
fail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := tt.setup()
|
||||
err := CheckTimeout(ctx)
|
||||
|
||||
if tt.fail {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
21
stream.go
21
stream.go
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
"github.com/imgproxy/imgproxy/v3/cookies"
|
||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||
"github.com/imgproxy/imgproxy/v3/metrics"
|
||||
"github.com/imgproxy/imgproxy/v3/metrics/stats"
|
||||
"github.com/imgproxy/imgproxy/v3/options"
|
||||
"github.com/imgproxy/imgproxy/v3/router"
|
||||
"github.com/imgproxy/imgproxy/v3/server"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -44,7 +45,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) {
|
||||
func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) error {
|
||||
stats.IncImagesInProgress()
|
||||
defer stats.DecImagesInProgress()
|
||||
|
||||
@@ -65,18 +66,24 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
|
||||
|
||||
if config.CookiePassthrough {
|
||||
cookieJar, err = cookies.JarFromRequest(r)
|
||||
checkErr(ctx, "streaming", err)
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
|
||||
}
|
||||
}
|
||||
|
||||
req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
|
||||
defer req.Cancel()
|
||||
checkErr(ctx, "streaming", err)
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
|
||||
}
|
||||
|
||||
res, err := req.Send()
|
||||
if res != nil {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
checkErr(ctx, "streaming", err)
|
||||
if err != nil {
|
||||
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
|
||||
}
|
||||
|
||||
for _, k := range streamRespHeaders {
|
||||
vv := res.Header.Values(k)
|
||||
@@ -116,7 +123,7 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
|
||||
copyerr = nil
|
||||
}
|
||||
|
||||
router.LogResponse(
|
||||
server.LogResponse(
|
||||
reqID, r, res.StatusCode, nil,
|
||||
log.Fields{
|
||||
"image_url": imageURL,
|
||||
@@ -127,4 +134,6 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
|
||||
if copyerr != nil {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user