mirror of
https://github.com/imgproxy/imgproxy.git
synced 2025-09-28 20:43:54 +02:00
IMG-51: router -> server ns, handlers ns, added error to handler ret val (#1494)
* Introduced server, handlers, error ret in handlerfn * Server struct with tests * replace checkErr with return
This commit is contained in:
26
errors.go
26
errors.go
@@ -7,11 +7,23 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Monitoring error categories
|
||||||
|
const (
|
||||||
|
categoryTimeout = "timeout"
|
||||||
|
categoryImageDataSize = "image_data_size"
|
||||||
|
categoryPathParsing = "path_parsing"
|
||||||
|
categorySecurity = "security"
|
||||||
|
categoryQueue = "queue"
|
||||||
|
categoryDownload = "download"
|
||||||
|
categoryProcessing = "processing"
|
||||||
|
categoryIO = "IO"
|
||||||
|
categoryStreaming = "streaming"
|
||||||
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
ResponseWriteError struct{ error }
|
ResponseWriteError struct{ error }
|
||||||
InvalidURLError string
|
InvalidURLError string
|
||||||
TooManyRequestsError struct{}
|
TooManyRequestsError struct{}
|
||||||
InvalidSecretError struct{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newResponseWriteError(cause error) *ierrors.Error {
|
func newResponseWriteError(cause error) *ierrors.Error {
|
||||||
@@ -53,15 +65,3 @@ func newTooManyRequestsError() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e TooManyRequestsError) Error() string { return "Too many requests" }
|
func (e TooManyRequestsError) Error() string { return "Too many requests" }
|
||||||
|
|
||||||
func newInvalidSecretError() error {
|
|
||||||
return ierrors.Wrap(
|
|
||||||
InvalidSecretError{},
|
|
||||||
1,
|
|
||||||
ierrors.WithStatusCode(http.StatusForbidden),
|
|
||||||
ierrors.WithPublicMessage("Forbidden"),
|
|
||||||
ierrors.WithShouldReport(false),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e InvalidSecretError) Error() string { return "Invalid secret" }
|
|
||||||
|
46
handlers/health.go
Normal file
46
handlers/health.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/vips"
|
||||||
|
)
|
||||||
|
|
||||||
|
var imgproxyIsRunningMsg = []byte("imgproxy is running")
|
||||||
|
|
||||||
|
// HealthHandler handles the health check requests
|
||||||
|
func HealthHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
var (
|
||||||
|
status int
|
||||||
|
msg []byte
|
||||||
|
ierr *ierrors.Error
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := vips.Health(); err == nil {
|
||||||
|
status = http.StatusOK
|
||||||
|
msg = imgproxyIsRunningMsg
|
||||||
|
} else {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
msg = []byte("Error")
|
||||||
|
ierr = ierrors.Wrap(err, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg) == 0 {
|
||||||
|
msg = []byte{' '}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log response only if something went wrong
|
||||||
|
if ierr != nil {
|
||||||
|
server.LogResponse(reqID, r, status, ierr)
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||||
|
rw.Header().Set(httpheaders.CacheControl, "no-cache")
|
||||||
|
rw.WriteHeader(status)
|
||||||
|
rw.Write(msg)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
32
handlers/health_test.go
Normal file
32
handlers/health_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthHandler(t *testing.T) {
|
||||||
|
// Create a ResponseRecorder to record the response
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler function directly (no need for actual HTTP request)
|
||||||
|
HealthHandler("test-req-id", rr, nil)
|
||||||
|
|
||||||
|
// Check that we get a valid response (either 200 or 500 depending on vips state)
|
||||||
|
assert.True(t, rr.Code == http.StatusOK || rr.Code == http.StatusInternalServerError)
|
||||||
|
|
||||||
|
// Check headers are set correctly
|
||||||
|
assert.Equal(t, "text/plain", rr.Header().Get(httpheaders.ContentType))
|
||||||
|
assert.Equal(t, "no-cache", rr.Header().Get(httpheaders.CacheControl))
|
||||||
|
|
||||||
|
// Verify response format and content
|
||||||
|
body := rr.Body.String()
|
||||||
|
assert.NotEmpty(t, body)
|
||||||
|
|
||||||
|
assert.Equal(t, "imgproxy is running", body)
|
||||||
|
}
|
@@ -1,6 +1,10 @@
|
|||||||
package main
|
package handlers
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
var landingTmpl = []byte(`
|
var landingTmpl = []byte(`
|
||||||
<!doctype html>
|
<!doctype html>
|
||||||
@@ -39,8 +43,10 @@ var landingTmpl = []byte(`
|
|||||||
</html>
|
</html>
|
||||||
`)
|
`)
|
||||||
|
|
||||||
func handleLanding(reqID string, rw http.ResponseWriter, r *http.Request) {
|
// LandingHandler handles the landing page requests
|
||||||
rw.Header().Set("Content-Type", "text/html")
|
func LandingHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
rw.WriteHeader(200)
|
rw.Header().Set(httpheaders.ContentType, "text/html")
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
rw.Write(landingTmpl)
|
rw.Write(landingTmpl)
|
||||||
|
return nil
|
||||||
}
|
}
|
@@ -17,6 +17,7 @@ const (
|
|||||||
AltSvc = "Alt-Svc"
|
AltSvc = "Alt-Svc"
|
||||||
Authorization = "Authorization"
|
Authorization = "Authorization"
|
||||||
CacheControl = "Cache-Control"
|
CacheControl = "Cache-Control"
|
||||||
|
CFConnectingIP = "CF-Connecting-IP"
|
||||||
Connection = "Connection"
|
Connection = "Connection"
|
||||||
ContentDisposition = "Content-Disposition"
|
ContentDisposition = "Content-Disposition"
|
||||||
ContentEncoding = "Content-Encoding"
|
ContentEncoding = "Content-Encoding"
|
||||||
@@ -56,6 +57,7 @@ const (
|
|||||||
Vary = "Vary"
|
Vary = "Vary"
|
||||||
Via = "Via"
|
Via = "Via"
|
||||||
WwwAuthenticate = "Www-Authenticate"
|
WwwAuthenticate = "Www-Authenticate"
|
||||||
|
XAmznRequestContextHeader = "x-amzn-request-context"
|
||||||
XContentTypeOptions = "X-Content-Type-Options"
|
XContentTypeOptions = "X-Content-Type-Options"
|
||||||
XForwardedFor = "X-Forwarded-For"
|
XForwardedFor = "X-Forwarded-For"
|
||||||
XForwardedHost = "X-Forwarded-Host"
|
XForwardedHost = "X-Forwarded-Host"
|
||||||
@@ -63,6 +65,8 @@ const (
|
|||||||
XFrameOptions = "X-Frame-Options"
|
XFrameOptions = "X-Frame-Options"
|
||||||
XOriginWidth = "X-Origin-Width"
|
XOriginWidth = "X-Origin-Width"
|
||||||
XOriginHeight = "X-Origin-Height"
|
XOriginHeight = "X-Origin-Height"
|
||||||
|
XRealIP = "X-Real-IP"
|
||||||
|
XRequestID = "X-Request-ID"
|
||||||
XResultWidth = "X-Result-Width"
|
XResultWidth = "X-Result-Width"
|
||||||
XResultHeight = "X-Result-Height"
|
XResultHeight = "X-Result-Height"
|
||||||
XOriginContentLength = "X-Origin-Content-Length"
|
XOriginContentLength = "X-Origin-Content-Length"
|
||||||
|
@@ -7,6 +7,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCategory = "default"
|
||||||
|
)
|
||||||
|
|
||||||
type Option func(*Error)
|
type Option func(*Error)
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
@@ -16,6 +20,7 @@ type Error struct {
|
|||||||
statusCode int
|
statusCode int
|
||||||
publicMessage string
|
publicMessage string
|
||||||
shouldReport bool
|
shouldReport bool
|
||||||
|
category string
|
||||||
|
|
||||||
stack []uintptr
|
stack []uintptr
|
||||||
}
|
}
|
||||||
@@ -64,6 +69,14 @@ func (e *Error) Callers() []uintptr {
|
|||||||
return e.stack
|
return e.stack
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Error) Category() string {
|
||||||
|
if e.category == "" {
|
||||||
|
return defaultCategory
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.category
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Error) FormatStackLines() []string {
|
func (e *Error) FormatStackLines() []string {
|
||||||
lines := make([]string, len(e.stack))
|
lines := make([]string, len(e.stack))
|
||||||
|
|
||||||
@@ -141,6 +154,12 @@ func WithShouldReport(report bool) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithCategory(category string) Option {
|
||||||
|
return func(e *Error) {
|
||||||
|
e.category = category
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func callers(skip int) []uintptr {
|
func callers(skip int) []uintptr {
|
||||||
stack := make([]uintptr, 10)
|
stack := make([]uintptr, 10)
|
||||||
n := runtime.Callers(skip+2, stack)
|
n := runtime.Callers(skip+2, stack)
|
||||||
|
44
main.go
44
main.go
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/config/loadenv"
|
"github.com/imgproxy/imgproxy/v3/config/loadenv"
|
||||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||||
"github.com/imgproxy/imgproxy/v3/gliblog"
|
"github.com/imgproxy/imgproxy/v3/gliblog"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/handlers"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/logger"
|
"github.com/imgproxy/imgproxy/v3/logger"
|
||||||
"github.com/imgproxy/imgproxy/v3/memory"
|
"github.com/imgproxy/imgproxy/v3/memory"
|
||||||
@@ -23,10 +24,37 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/metrics/prometheus"
|
"github.com/imgproxy/imgproxy/v3/metrics/prometheus"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/processing"
|
"github.com/imgproxy/imgproxy/v3/processing"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
"github.com/imgproxy/imgproxy/v3/version"
|
"github.com/imgproxy/imgproxy/v3/version"
|
||||||
"github.com/imgproxy/imgproxy/v3/vips"
|
"github.com/imgproxy/imgproxy/v3/vips"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
faviconPath = "/favicon.ico"
|
||||||
|
healthPath = "/health"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildRouter(r *server.Router) *server.Router {
|
||||||
|
r.GET("/", true, handlers.LandingHandler)
|
||||||
|
r.GET("", true, handlers.LandingHandler)
|
||||||
|
|
||||||
|
r.GET(
|
||||||
|
"/", false, handleProcessing,
|
||||||
|
r.WithSecret, r.WithCORS, r.WithPanic, r.WithReportError, r.WithMetrics,
|
||||||
|
)
|
||||||
|
|
||||||
|
r.HEAD("/", false, r.OkHandler, r.WithCORS)
|
||||||
|
r.OPTIONS("/", false, r.OkHandler, r.WithCORS)
|
||||||
|
|
||||||
|
r.GET(faviconPath, true, r.NotFoundHandler).Silent()
|
||||||
|
r.GET(healthPath, true, handlers.HealthHandler).Silent()
|
||||||
|
if config.HealthCheckPath != "" {
|
||||||
|
r.GET(config.HealthCheckPath, true, handlers.HealthHandler).Silent()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func initialize() error {
|
func initialize() error {
|
||||||
if err := loadenv.Load(); err != nil {
|
if err := loadenv.Load(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -103,25 +131,21 @@ func run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
if err := prometheus.StartServer(cancel); err != nil {
|
if err := prometheus.StartServer(cancel); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := startServer(cancel)
|
cfg := server.NewConfigFromEnv()
|
||||||
|
r := server.NewRouter(cfg)
|
||||||
|
s, err := server.Start(cancel, buildRouter(r))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer shutdownServer(s)
|
defer s.Shutdown(ctx)
|
||||||
|
|
||||||
stop := make(chan os.Signal, 1)
|
<-ctx.Done()
|
||||||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case <-stop:
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -161,7 +161,7 @@ func StartServer(cancel context.CancelFunc) error {
|
|||||||
|
|
||||||
s := http.Server{Handler: promhttp.Handler()}
|
s := http.Server{Handler: promhttp.Handler()}
|
||||||
|
|
||||||
l, err := reuseport.Listen("tcp", config.PrometheusBind)
|
l, err := reuseport.Listen("tcp", config.PrometheusBind, config.SoReuseport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Can't start Prometheus metrics server: %s", err)
|
return fmt.Errorf("Can't start Prometheus metrics server: %s", err)
|
||||||
}
|
}
|
||||||
|
@@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/router"
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
"github.com/imgproxy/imgproxy/v3/vips"
|
"github.com/imgproxy/imgproxy/v3/vips"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ func (p pipeline) Run(ctx context.Context, img *vips.Image, po *options.Processi
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := router.CheckTimeout(ctx); err != nil {
|
if err := server.CheckTimeout(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -12,8 +12,8 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/router"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/security"
|
"github.com/imgproxy/imgproxy/v3/security"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
"github.com/imgproxy/imgproxy/v3/svg"
|
"github.com/imgproxy/imgproxy/v3/svg"
|
||||||
"github.com/imgproxy/imgproxy/v3/vips"
|
"github.com/imgproxy/imgproxy/v3/vips"
|
||||||
)
|
)
|
||||||
@@ -173,7 +173,7 @@ func transformAnimated(ctx context.Context, img *vips.Image, po *options.Process
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = router.CheckTimeout(ctx); err != nil {
|
if err = server.CheckTimeout(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -240,7 +240,7 @@ func saveImageToFitBytes(ctx context.Context, po *options.ProcessingOptions, img
|
|||||||
}
|
}
|
||||||
imgdata.Close()
|
imgdata.Close()
|
||||||
|
|
||||||
if err := router.CheckTimeout(ctx); err != nil {
|
if err := server.CheckTimeout(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -27,8 +26,8 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/metrics/stats"
|
"github.com/imgproxy/imgproxy/v3/metrics/stats"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/processing"
|
"github.com/imgproxy/imgproxy/v3/processing"
|
||||||
"github.com/imgproxy/imgproxy/v3/router"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/security"
|
"github.com/imgproxy/imgproxy/v3/security"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
"github.com/imgproxy/imgproxy/v3/vips"
|
"github.com/imgproxy/imgproxy/v3/vips"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -122,17 +121,19 @@ func setCanonical(rw http.ResponseWriter, originURL string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeOriginContentLengthDebugHeader(ctx context.Context, rw http.ResponseWriter, originData imagedata.ImageData) {
|
func writeOriginContentLengthDebugHeader(rw http.ResponseWriter, originData imagedata.ImageData) error {
|
||||||
if !config.EnableDebugHeaders {
|
if !config.EnableDebugHeaders {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
size, err := originData.Size()
|
size, err := originData.Size()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
checkErr(ctx, "image_data_size", err)
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size))
|
rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size))
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
|
func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
|
||||||
@@ -146,13 +147,13 @@ func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result) {
|
|||||||
rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight))
|
rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight))
|
||||||
}
|
}
|
||||||
|
|
||||||
func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) {
|
func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, statusCode int, resultData imagedata.ImageData, po *options.ProcessingOptions, originURL string, originData imagedata.ImageData, originHeaders http.Header) error {
|
||||||
// We read the size of the image data here, so we can set Content-Length header.
|
// We read the size of the image data here, so we can set Content-Length header.
|
||||||
// This indireclty ensures that the image data is fully read from the source, no
|
// This indireclty ensures that the image data is fully read from the source, no
|
||||||
// errors happened.
|
// errors happened.
|
||||||
resultSize, err := resultData.Size()
|
resultSize, err := resultData.Size()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
checkErr(r.Context(), "image_data_size", err)
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
contentDisposition := httpheaders.ContentDispositionValue(
|
contentDisposition := httpheaders.ContentDispositionValue(
|
||||||
@@ -183,18 +184,19 @@ func respondWithImage(reqID string, r *http.Request, rw http.ResponseWriter, sta
|
|||||||
ierr = newResponseWriteError(err)
|
ierr = newResponseWriteError(err)
|
||||||
|
|
||||||
if config.ReportIOErrors {
|
if config.ReportIOErrors {
|
||||||
sendErr(r.Context(), "IO", ierr)
|
return ierrors.Wrap(ierr, 0, ierrors.WithCategory(categoryIO), ierrors.WithShouldReport(true))
|
||||||
errorreport.Report(ierr, r)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
router.LogResponse(
|
server.LogResponse(
|
||||||
reqID, r, statusCode, ierr,
|
reqID, r, statusCode, ierr,
|
||||||
log.Fields{
|
log.Fields{
|
||||||
"image_url": originURL,
|
"image_url": originURL,
|
||||||
"processing_options": po,
|
"processing_options": po,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders http.Header) {
|
func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, originHeaders http.Header) {
|
||||||
@@ -202,7 +204,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite
|
|||||||
setVary(rw)
|
setVary(rw)
|
||||||
|
|
||||||
rw.WriteHeader(304)
|
rw.WriteHeader(304)
|
||||||
router.LogResponse(
|
server.LogResponse(
|
||||||
reqID, r, 304, nil,
|
reqID, r, 304, nil,
|
||||||
log.Fields{
|
log.Fields{
|
||||||
"image_url": originURL,
|
"image_url": originURL,
|
||||||
@@ -211,37 +213,7 @@ func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWrite
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendErr(ctx context.Context, errType string, err error) {
|
func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
send := true
|
|
||||||
|
|
||||||
if ierr, ok := err.(*ierrors.Error); ok {
|
|
||||||
switch ierr.StatusCode() {
|
|
||||||
case http.StatusServiceUnavailable:
|
|
||||||
errType = "timeout"
|
|
||||||
case 499:
|
|
||||||
// Don't need to send a "request cancelled" error
|
|
||||||
send = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if send {
|
|
||||||
metrics.SendError(ctx, errType, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendErrAndPanic(ctx context.Context, errType string, err error) {
|
|
||||||
sendErr(ctx, errType, err)
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkErr(ctx context.Context, errType string, err error) {
|
|
||||||
if err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sendErrAndPanic(ctx, errType, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
stats.IncRequestsInProgress()
|
stats.IncRequestsInProgress()
|
||||||
defer stats.DecRequestsInProgress()
|
defer stats.DecRequestsInProgress()
|
||||||
|
|
||||||
@@ -263,19 +235,22 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
signature = path[:signatureEnd]
|
signature = path[:signatureEnd]
|
||||||
path = path[signatureEnd:]
|
path = path[signatureEnd:]
|
||||||
} else {
|
} else {
|
||||||
sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf(
|
return ierrors.Wrap(
|
||||||
http.StatusNotFound, "Invalid path: %s", path),
|
newInvalidURLErrorf(http.StatusNotFound, "Invalid path: %s", path), 0,
|
||||||
|
ierrors.WithCategory(categoryPathParsing),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
path = fixPath(path)
|
path = fixPath(path)
|
||||||
|
|
||||||
if err := security.VerifySignature(signature, path); err != nil {
|
if err := security.VerifySignature(signature, path); err != nil {
|
||||||
sendErrAndPanic(ctx, "security", err)
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity))
|
||||||
}
|
}
|
||||||
|
|
||||||
po, imageURL, err := options.ParsePath(path, r.Header)
|
po, imageURL, err := options.ParsePath(path, r.Header)
|
||||||
checkErr(ctx, "path_parsing", err)
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryPathParsing))
|
||||||
|
}
|
||||||
|
|
||||||
var imageOrigin any
|
var imageOrigin any
|
||||||
if u, uerr := url.Parse(imageURL); uerr == nil {
|
if u, uerr := url.Parse(imageURL); uerr == nil {
|
||||||
@@ -295,19 +270,20 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
metrics.SetMetadata(ctx, metricsMeta)
|
metrics.SetMetadata(ctx, metricsMeta)
|
||||||
|
|
||||||
err = security.VerifySourceURL(imageURL)
|
err = security.VerifySourceURL(imageURL)
|
||||||
checkErr(ctx, "security", err)
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity))
|
||||||
|
}
|
||||||
|
|
||||||
if po.Raw {
|
if po.Raw {
|
||||||
streamOriginImage(ctx, reqID, r, rw, po, imageURL)
|
return streamOriginImage(ctx, reqID, r, rw, po, imageURL)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SVG is a special case. Though saving to svg is not supported, SVG->SVG is.
|
// SVG is a special case. Though saving to svg is not supported, SVG->SVG is.
|
||||||
if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG {
|
if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG {
|
||||||
sendErrAndPanic(ctx, "path_parsing", newInvalidURLErrorf(
|
return ierrors.Wrap(newInvalidURLErrorf(
|
||||||
http.StatusUnprocessableEntity,
|
http.StatusUnprocessableEntity,
|
||||||
"Resulting image format is not supported: %s", po.Format,
|
"Resulting image format is not supported: %s", po.Format,
|
||||||
))
|
), 0, ierrors.WithCategory(categoryPathParsing))
|
||||||
}
|
}
|
||||||
|
|
||||||
imgRequestHeader := make(http.Header)
|
imgRequestHeader := make(http.Header)
|
||||||
@@ -339,7 +315,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The heavy part starts here, so we need to restrict worker number
|
// The heavy part starts here, so we need to restrict worker number
|
||||||
func() {
|
err = func() error {
|
||||||
defer metrics.StartQueueSegment(ctx)()
|
defer metrics.StartQueueSegment(ctx)()
|
||||||
|
|
||||||
err = processingSem.Acquire(ctx, 1)
|
err = processingSem.Acquire(ctx, 1)
|
||||||
@@ -347,12 +323,21 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
// We don't actually need to check timeout here,
|
// We don't actually need to check timeout here,
|
||||||
// but it's an easy way to check if this is an actual timeout
|
// but it's an easy way to check if this is an actual timeout
|
||||||
// or the request was canceled
|
// or the request was canceled
|
||||||
checkErr(ctx, "queue", router.CheckTimeout(ctx))
|
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||||
|
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
// We should never reach this line as err could be only ctx.Err()
|
// We should never reach this line as err could be only ctx.Err()
|
||||||
// and we've already checked for it. But beter safe than sorry
|
// and we've already checked for it. But beter safe than sorry
|
||||||
sendErrAndPanic(ctx, "queue", err)
|
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryQueue))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}()
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
defer processingSem.Release(1)
|
defer processingSem.Release(1)
|
||||||
|
|
||||||
stats.IncImagesInProgress()
|
stats.IncImagesInProgress()
|
||||||
@@ -375,7 +360,9 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if config.CookiePassthrough {
|
if config.CookiePassthrough {
|
||||||
downloadOpts.CookieJar, err = cookies.JarFromRequest(r)
|
downloadOpts.CookieJar, err = cookies.JarFromRequest(r)
|
||||||
checkErr(ctx, "download", err)
|
if err != nil {
|
||||||
|
return nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts)
|
return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts)
|
||||||
@@ -393,26 +380,28 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers())
|
respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers())
|
||||||
return
|
return nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// This may be a request timeout error or a request cancelled error.
|
// This may be a request timeout error or a request cancelled error.
|
||||||
// Check it before moving further
|
// Check it before moving further
|
||||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||||
|
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
ierr := ierrors.Wrap(err, 0)
|
ierr := ierrors.Wrap(err, 0)
|
||||||
if config.ReportDownloadingErrors {
|
if config.ReportDownloadingErrors {
|
||||||
ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true))
|
ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true))
|
||||||
}
|
}
|
||||||
|
|
||||||
sendErr(ctx, "download", ierr)
|
|
||||||
|
|
||||||
if imagedata.FallbackImage == nil {
|
if imagedata.FallbackImage == nil {
|
||||||
panic(ierr)
|
return ierr
|
||||||
}
|
}
|
||||||
|
|
||||||
// We didn't panic, so the error is not reported.
|
// Just send error
|
||||||
// Report it now
|
metrics.SendError(ctx, categoryDownload, ierr)
|
||||||
|
|
||||||
|
// We didn't return, so we have to report error
|
||||||
if ierr.ShouldReport() {
|
if ierr.ShouldReport() {
|
||||||
errorreport.Report(ierr, r)
|
errorreport.Report(ierr, r)
|
||||||
}
|
}
|
||||||
@@ -433,27 +422,33 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||||
|
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
if config.ETagEnabled && statusCode == http.StatusOK {
|
if config.ETagEnabled && statusCode == http.StatusOK {
|
||||||
imgDataMatch, terr := etagHandler.SetActualImageData(originData, originHeaders)
|
imgDataMatch, eerr := etagHandler.SetActualImageData(originData, originHeaders)
|
||||||
if terr == nil {
|
if eerr != nil && config.ReportIOErrors {
|
||||||
rw.Header().Set("ETag", etagHandler.GenerateActualETag())
|
return ierrors.Wrap(eerr, 0, ierrors.WithCategory(categoryIO))
|
||||||
|
}
|
||||||
|
|
||||||
if imgDataMatch && etagHandler.ProcessingOptionsMatch() {
|
rw.Header().Set("ETag", etagHandler.GenerateActualETag())
|
||||||
respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders)
|
|
||||||
return
|
if imgDataMatch && etagHandler.ProcessingOptionsMatch() {
|
||||||
}
|
respondWithNotModified(reqID, r, rw, po, imageURL, originHeaders)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||||
|
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
if !vips.SupportsLoad(originData.Format()) {
|
if !vips.SupportsLoad(originData.Format()) {
|
||||||
sendErrAndPanic(ctx, "processing", newInvalidURLErrorf(
|
return ierrors.Wrap(newInvalidURLErrorf(
|
||||||
http.StatusUnprocessableEntity,
|
http.StatusUnprocessableEntity,
|
||||||
"Source image format is not supported: %s", originData.Format(),
|
"Source image format is not supported: %s", originData.Format(),
|
||||||
))
|
), 0, ierrors.WithCategory(categoryProcessing))
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := func() (*processing.Result, error) {
|
result, err := func() (*processing.Result, error) {
|
||||||
@@ -468,18 +463,31 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|||||||
defer result.OutData.Close()
|
defer result.OutData.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
// First, check if the processing error wasn't caused by an image data error
|
||||||
// First, check if the processing error wasn't caused by an image data error
|
if originData.Error() != nil {
|
||||||
checkErr(ctx, "download", originData.Error())
|
return ierrors.Wrap(originData.Error(), 0, ierrors.WithCategory(categoryDownload))
|
||||||
|
|
||||||
// If it wasn't, than it was a processing error
|
|
||||||
sendErrAndPanic(ctx, "processing", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
checkErr(ctx, "timeout", router.CheckTimeout(ctx))
|
// If it wasn't, than it was a processing error
|
||||||
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryProcessing))
|
||||||
|
}
|
||||||
|
|
||||||
|
if terr := server.CheckTimeout(ctx); terr != nil {
|
||||||
|
return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
writeDebugHeaders(rw, result)
|
writeDebugHeaders(rw, result)
|
||||||
writeOriginContentLengthDebugHeader(ctx, rw, originData)
|
|
||||||
|
|
||||||
respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders)
|
err = writeOriginContentLengthDebugHeader(rw, originData)
|
||||||
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, originData, originHeaders)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -22,7 +22,7 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagetype"
|
"github.com/imgproxy/imgproxy/v3/imagetype"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/router"
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
"github.com/imgproxy/imgproxy/v3/svg"
|
"github.com/imgproxy/imgproxy/v3/svg"
|
||||||
"github.com/imgproxy/imgproxy/v3/testutil"
|
"github.com/imgproxy/imgproxy/v3/testutil"
|
||||||
"github.com/imgproxy/imgproxy/v3/vips"
|
"github.com/imgproxy/imgproxy/v3/vips"
|
||||||
@@ -31,7 +31,7 @@ import (
|
|||||||
type ProcessingHandlerTestSuite struct {
|
type ProcessingHandlerTestSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
|
|
||||||
router *router.Router
|
router *server.Router
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProcessingHandlerTestSuite) SetupSuite() {
|
func (s *ProcessingHandlerTestSuite) SetupSuite() {
|
||||||
@@ -48,7 +48,7 @@ func (s *ProcessingHandlerTestSuite) SetupSuite() {
|
|||||||
|
|
||||||
logrus.SetOutput(io.Discard)
|
logrus.SetOutput(io.Discard)
|
||||||
|
|
||||||
s.router = buildRouter()
|
s.router = buildRouter(server.NewRouter(server.NewConfigFromEnv()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProcessingHandlerTestSuite) TeardownSuite() {
|
func (s *ProcessingHandlerTestSuite) TeardownSuite() {
|
||||||
|
@@ -7,12 +7,10 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Listen(network, address string) (net.Listener, error) {
|
func Listen(network, address string, reuse bool) (net.Listener, error) {
|
||||||
if config.SoReuseport {
|
if reuse {
|
||||||
log.Warning("SO_REUSEPORT support is not implemented for your OS or Go version")
|
log.Warning("SO_REUSEPORT support is not implemented for your OS or Go version")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -10,12 +10,10 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Listen(network, address string) (net.Listener, error) {
|
func Listen(network, address string, reuse bool) (net.Listener, error) {
|
||||||
if !config.SoReuseport {
|
if !reuse {
|
||||||
return net.Listen(network, address)
|
return net.Listen(network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
174
router/router.go
174
router/router.go
@@ -1,174 +0,0 @@
|
|||||||
package router
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
nanoid "github.com/matoous/go-nanoid/v2"
|
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
xRequestIDHeader = "X-Request-ID"
|
|
||||||
xAmznRequestContextHeader = "x-amzn-request-context"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
|
||||||
)
|
|
||||||
|
|
||||||
type RouteHandler func(string, http.ResponseWriter, *http.Request)
|
|
||||||
|
|
||||||
type route struct {
|
|
||||||
Method string
|
|
||||||
Prefix string
|
|
||||||
Handler RouteHandler
|
|
||||||
Exact bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Router struct {
|
|
||||||
prefix string
|
|
||||||
healthRoutes []string
|
|
||||||
faviconRoute string
|
|
||||||
|
|
||||||
Routes []*route
|
|
||||||
HealthHandler RouteHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *route) isMatch(req *http.Request) bool {
|
|
||||||
if r.Method != req.Method {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Exact {
|
|
||||||
return req.URL.Path == r.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.HasPrefix(req.URL.Path, r.Prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(prefix string) *Router {
|
|
||||||
healthRoutes := []string{prefix + "/health"}
|
|
||||||
if len(config.HealthCheckPath) > 0 {
|
|
||||||
healthRoutes = append(healthRoutes, prefix+config.HealthCheckPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Router{
|
|
||||||
prefix: prefix,
|
|
||||||
healthRoutes: healthRoutes,
|
|
||||||
faviconRoute: prefix + "/favicon.ico",
|
|
||||||
Routes: make([]*route, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) Add(method, prefix string, handler RouteHandler, exact bool) {
|
|
||||||
// Don't add routes with empty prefix
|
|
||||||
if len(r.prefix+prefix) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Routes = append(
|
|
||||||
r.Routes,
|
|
||||||
&route{Method: method, Prefix: r.prefix + prefix, Handler: handler, Exact: exact},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) GET(prefix string, handler RouteHandler, exact bool) {
|
|
||||||
r.Add(http.MethodGet, prefix, handler, exact)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) OPTIONS(prefix string, handler RouteHandler, exact bool) {
|
|
||||||
r.Add(http.MethodOptions, prefix, handler, exact)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) HEAD(prefix string, handler RouteHandler, exact bool) {
|
|
||||||
r.Add(http.MethodHead, prefix, handler, exact)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
req, timeoutCancel := startRequestTimer(req)
|
|
||||||
defer timeoutCancel()
|
|
||||||
|
|
||||||
rw = newTimeoutResponse(rw)
|
|
||||||
|
|
||||||
reqID := req.Header.Get(xRequestIDHeader)
|
|
||||||
|
|
||||||
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
|
||||||
if lambdaContextVal := req.Header.Get(xAmznRequestContextHeader); len(lambdaContextVal) > 0 {
|
|
||||||
var lambdaContext struct {
|
|
||||||
RequestID string `json:"requestId"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext)
|
|
||||||
if err == nil && len(lambdaContext.RequestID) > 0 {
|
|
||||||
reqID = lambdaContext.RequestID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
|
||||||
reqID, _ = nanoid.New()
|
|
||||||
}
|
|
||||||
|
|
||||||
rw.Header().Set("Server", "imgproxy")
|
|
||||||
rw.Header().Set(xRequestIDHeader, reqID)
|
|
||||||
|
|
||||||
if req.Method == http.MethodGet {
|
|
||||||
if r.HealthHandler != nil {
|
|
||||||
for _, healthRoute := range r.healthRoutes {
|
|
||||||
if req.URL.Path == healthRoute {
|
|
||||||
r.HealthHandler(reqID, rw, req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.URL.Path == r.faviconRoute {
|
|
||||||
// TODO: Add a real favicon maybe?
|
|
||||||
rw.Header().Set("Content-Type", "text/plain")
|
|
||||||
rw.WriteHeader(404)
|
|
||||||
// Write a single byte to make AWS Lambda happy
|
|
||||||
rw.Write([]byte{' '})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip := req.Header.Get("CF-Connecting-IP"); len(ip) != 0 {
|
|
||||||
replaceRemoteAddr(req, ip)
|
|
||||||
} else if ip := req.Header.Get("X-Forwarded-For"); len(ip) != 0 {
|
|
||||||
if index := strings.Index(ip, ","); index > 0 {
|
|
||||||
ip = ip[:index]
|
|
||||||
}
|
|
||||||
replaceRemoteAddr(req, ip)
|
|
||||||
} else if ip := req.Header.Get("X-Real-IP"); len(ip) != 0 {
|
|
||||||
replaceRemoteAddr(req, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
LogRequest(reqID, req)
|
|
||||||
|
|
||||||
for _, rr := range r.Routes {
|
|
||||||
if rr.isMatch(req) {
|
|
||||||
rr.Handler(reqID, rw, req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LogResponse(reqID, req, 404, newRouteNotDefinedError(req.URL.Path))
|
|
||||||
|
|
||||||
rw.Header().Set("Content-Type", "text/plain")
|
|
||||||
rw.WriteHeader(404)
|
|
||||||
rw.Write([]byte{' '})
|
|
||||||
}
|
|
||||||
|
|
||||||
func replaceRemoteAddr(req *http.Request, ip string) {
|
|
||||||
_, port, err := net.SplitHostPort(req.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
port = "80"
|
|
||||||
}
|
|
||||||
|
|
||||||
req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port)
|
|
||||||
}
|
|
@@ -1,43 +0,0 @@
|
|||||||
package router
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
type timeoutResponse struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
controller *http.ResponseController
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTimeoutResponse(rw http.ResponseWriter) http.ResponseWriter {
|
|
||||||
return &timeoutResponse{
|
|
||||||
ResponseWriter: rw,
|
|
||||||
controller: http.NewResponseController(rw),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *timeoutResponse) WriteHeader(statusCode int) {
|
|
||||||
rw.withWriteDeadline(func() {
|
|
||||||
rw.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *timeoutResponse) Write(b []byte) (int, error) {
|
|
||||||
var (
|
|
||||||
n int
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
rw.withWriteDeadline(func() {
|
|
||||||
n, err = rw.ResponseWriter.Write(b)
|
|
||||||
})
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *timeoutResponse) withWriteDeadline(f func()) {
|
|
||||||
rw.controller.SetWriteDeadline(time.Now().Add(time.Duration(config.WriteResponseTimeout) * time.Second))
|
|
||||||
defer rw.controller.SetWriteDeadline(time.Time{})
|
|
||||||
f()
|
|
||||||
}
|
|
204
server.go
204
server.go
@@ -1,204 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/subtle"
|
|
||||||
"fmt"
|
|
||||||
golog "log"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/net/netutil"
|
|
||||||
|
|
||||||
"github.com/imgproxy/imgproxy/v3/config"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/errorreport"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/metrics"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/reuseport"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/router"
|
|
||||||
"github.com/imgproxy/imgproxy/v3/vips"
|
|
||||||
)
|
|
||||||
|
|
||||||
var imgproxyIsRunningMsg = []byte("imgproxy is running")
|
|
||||||
|
|
||||||
func buildRouter() *router.Router {
|
|
||||||
r := router.New(config.PathPrefix)
|
|
||||||
|
|
||||||
r.GET("/", handleLanding, true)
|
|
||||||
r.GET("", handleLanding, true)
|
|
||||||
|
|
||||||
r.GET("/", withMetrics(withPanicHandler(withCORS(withSecret(handleProcessing)))), false)
|
|
||||||
|
|
||||||
r.HEAD("/", withCORS(handleHead), false)
|
|
||||||
r.OPTIONS("/", withCORS(handleHead), false)
|
|
||||||
|
|
||||||
r.HealthHandler = handleHealth
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func startServer(cancel context.CancelFunc) (*http.Server, error) {
|
|
||||||
l, err := reuseport.Listen(config.Network, config.Bind)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("can't start server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.MaxClients > 0 {
|
|
||||||
l = netutil.LimitListener(l, config.MaxClients)
|
|
||||||
}
|
|
||||||
|
|
||||||
errLogger := golog.New(
|
|
||||||
log.WithField("source", "http_server").WriterLevel(log.ErrorLevel),
|
|
||||||
"", 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
s := &http.Server{
|
|
||||||
Handler: buildRouter(),
|
|
||||||
ReadTimeout: time.Duration(config.ReadRequestTimeout) * time.Second,
|
|
||||||
MaxHeaderBytes: 1 << 20,
|
|
||||||
ErrorLog: errLogger,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.KeepAliveTimeout > 0 {
|
|
||||||
s.IdleTimeout = time.Duration(config.KeepAliveTimeout) * time.Second
|
|
||||||
} else {
|
|
||||||
s.SetKeepAlivesEnabled(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Infof("Starting server at %s", config.Bind)
|
|
||||||
if err := s.Serve(l); err != nil && err != http.ErrServerClosed {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func shutdownServer(s *http.Server) {
|
|
||||||
log.Info("Shutting down the server...")
|
|
||||||
|
|
||||||
ctx, close := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer close()
|
|
||||||
|
|
||||||
s.Shutdown(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func withMetrics(h router.RouteHandler) router.RouteHandler {
|
|
||||||
if !metrics.Enabled() {
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, metricsCancel, rw := metrics.StartRequest(r.Context(), rw, r)
|
|
||||||
defer metricsCancel()
|
|
||||||
|
|
||||||
h(reqID, rw, r.WithContext(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func withCORS(h router.RouteHandler) router.RouteHandler {
|
|
||||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
if len(config.AllowOrigin) > 0 {
|
|
||||||
rw.Header().Set("Access-Control-Allow-Origin", config.AllowOrigin)
|
|
||||||
rw.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
h(reqID, rw, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func withSecret(h router.RouteHandler) router.RouteHandler {
|
|
||||||
if len(config.Secret) == 0 {
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
authHeader := []byte(fmt.Sprintf("Bearer %s", config.Secret))
|
|
||||||
|
|
||||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
if subtle.ConstantTimeCompare([]byte(r.Header.Get("Authorization")), authHeader) == 1 {
|
|
||||||
h(reqID, rw, r)
|
|
||||||
} else {
|
|
||||||
panic(newInvalidSecretError())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func withPanicHandler(h router.RouteHandler) router.RouteHandler {
|
|
||||||
return func(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := errorreport.StartRequest(r)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
errorreport.SetMetadata(r, "Request ID", reqID)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if rerr := recover(); rerr != nil {
|
|
||||||
if rerr == http.ErrAbortHandler {
|
|
||||||
panic(rerr)
|
|
||||||
}
|
|
||||||
|
|
||||||
err, ok := rerr.(error)
|
|
||||||
if !ok {
|
|
||||||
panic(rerr)
|
|
||||||
}
|
|
||||||
|
|
||||||
ierr := ierrors.Wrap(err, 0)
|
|
||||||
|
|
||||||
if ierr.ShouldReport() {
|
|
||||||
errorreport.Report(err, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
router.LogResponse(reqID, r, ierr.StatusCode(), ierr)
|
|
||||||
|
|
||||||
rw.Header().Set("Content-Type", "text/plain")
|
|
||||||
rw.WriteHeader(ierr.StatusCode())
|
|
||||||
|
|
||||||
if config.DevelopmentErrorsMode {
|
|
||||||
rw.Write([]byte(ierr.Error()))
|
|
||||||
} else {
|
|
||||||
rw.Write([]byte(ierr.PublicMessage()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
h(reqID, rw, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleHealth(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
var (
|
|
||||||
status int
|
|
||||||
msg []byte
|
|
||||||
ierr *ierrors.Error
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := vips.Health(); err == nil {
|
|
||||||
status = http.StatusOK
|
|
||||||
msg = imgproxyIsRunningMsg
|
|
||||||
} else {
|
|
||||||
status = http.StatusInternalServerError
|
|
||||||
msg = []byte("Error")
|
|
||||||
ierr = ierrors.Wrap(err, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(msg) == 0 {
|
|
||||||
msg = []byte{' '}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log response only if something went wrong
|
|
||||||
if ierr != nil {
|
|
||||||
router.LogResponse(reqID, r, status, ierr)
|
|
||||||
}
|
|
||||||
|
|
||||||
rw.Header().Set("Content-Type", "text/plain")
|
|
||||||
rw.Header().Set("Cache-Control", "no-cache")
|
|
||||||
rw.WriteHeader(status)
|
|
||||||
rw.Write(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleHead(reqID string, rw http.ResponseWriter, r *http.Request) {
|
|
||||||
router.LogResponse(reqID, r, 200, nil)
|
|
||||||
rw.WriteHeader(200)
|
|
||||||
}
|
|
49
server/config.go
Normal file
49
server/config.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// gracefulTimeout represents graceful shutdown timeout
|
||||||
|
gracefulTimeout = time.Duration(5 * time.Second)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config represents HTTP server config
|
||||||
|
type Config struct {
|
||||||
|
Listen string // Address to listen on
|
||||||
|
Network string // Network type (tcp, unix)
|
||||||
|
Bind string // Bind address
|
||||||
|
PathPrefix string // Path prefix for the server
|
||||||
|
MaxClients int // Maximum number of concurrent clients
|
||||||
|
ReadRequestTimeout time.Duration // Timeout for reading requests
|
||||||
|
KeepAliveTimeout time.Duration // Timeout for keep-alive connections
|
||||||
|
GracefulTimeout time.Duration // Timeout for graceful shutdown
|
||||||
|
CORSAllowOrigin string // CORS allowed origin
|
||||||
|
Secret string // Secret for authorization
|
||||||
|
DevelopmentErrorsMode bool // Enable development mode for detailed error messages
|
||||||
|
SocketReusePort bool // Enable SO_REUSEPORT socket option
|
||||||
|
HealthCheckPath string // Health check path from config
|
||||||
|
WriteResponseTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigFromEnv creates a new Config instance from environment variables
|
||||||
|
func NewConfigFromEnv() *Config {
|
||||||
|
return &Config{
|
||||||
|
Network: config.Network,
|
||||||
|
Bind: config.Bind,
|
||||||
|
PathPrefix: config.PathPrefix,
|
||||||
|
MaxClients: config.MaxClients,
|
||||||
|
ReadRequestTimeout: time.Duration(config.ReadRequestTimeout) * time.Second,
|
||||||
|
KeepAliveTimeout: time.Duration(config.KeepAliveTimeout) * time.Second,
|
||||||
|
GracefulTimeout: gracefulTimeout,
|
||||||
|
CORSAllowOrigin: config.AllowOrigin,
|
||||||
|
Secret: config.Secret,
|
||||||
|
DevelopmentErrorsMode: config.DevelopmentErrorsMode,
|
||||||
|
SocketReusePort: config.SoReuseport,
|
||||||
|
HealthCheckPath: config.HealthCheckPath,
|
||||||
|
WriteResponseTimeout: time.Duration(config.WriteResponseTimeout) * time.Second,
|
||||||
|
}
|
||||||
|
}
|
@@ -1,6 +1,7 @@
|
|||||||
package router
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,6 +13,7 @@ type (
|
|||||||
RouteNotDefinedError string
|
RouteNotDefinedError string
|
||||||
RequestCancelledError string
|
RequestCancelledError string
|
||||||
RequestTimeoutError string
|
RequestTimeoutError string
|
||||||
|
InvalidSecretError struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newRouteNotDefinedError(path string) *ierrors.Error {
|
func newRouteNotDefinedError(path string) *ierrors.Error {
|
||||||
@@ -33,11 +35,16 @@ func newRequestCancelledError(after time.Duration) *ierrors.Error {
|
|||||||
ierrors.WithStatusCode(499),
|
ierrors.WithStatusCode(499),
|
||||||
ierrors.WithPublicMessage("Cancelled"),
|
ierrors.WithPublicMessage("Cancelled"),
|
||||||
ierrors.WithShouldReport(false),
|
ierrors.WithShouldReport(false),
|
||||||
|
ierrors.WithCategory(categoryTimeout),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e RequestCancelledError) Error() string { return string(e) }
|
func (e RequestCancelledError) Error() string { return string(e) }
|
||||||
|
|
||||||
|
func (e RequestCancelledError) Unwrap() error {
|
||||||
|
return context.Canceled
|
||||||
|
}
|
||||||
|
|
||||||
func newRequestTimeoutError(after time.Duration) *ierrors.Error {
|
func newRequestTimeoutError(after time.Duration) *ierrors.Error {
|
||||||
return ierrors.Wrap(
|
return ierrors.Wrap(
|
||||||
RequestTimeoutError(fmt.Sprintf("Request was timed out after %v", after)),
|
RequestTimeoutError(fmt.Sprintf("Request was timed out after %v", after)),
|
||||||
@@ -45,7 +52,24 @@ func newRequestTimeoutError(after time.Duration) *ierrors.Error {
|
|||||||
ierrors.WithStatusCode(http.StatusServiceUnavailable),
|
ierrors.WithStatusCode(http.StatusServiceUnavailable),
|
||||||
ierrors.WithPublicMessage("Gateway Timeout"),
|
ierrors.WithPublicMessage("Gateway Timeout"),
|
||||||
ierrors.WithShouldReport(false),
|
ierrors.WithShouldReport(false),
|
||||||
|
ierrors.WithCategory(categoryTimeout),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e RequestTimeoutError) Error() string { return string(e) }
|
func (e RequestTimeoutError) Error() string { return string(e) }
|
||||||
|
|
||||||
|
func (e RequestTimeoutError) Unwrap() error {
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInvalidSecretError() error {
|
||||||
|
return ierrors.Wrap(
|
||||||
|
InvalidSecretError{},
|
||||||
|
1,
|
||||||
|
ierrors.WithStatusCode(http.StatusForbidden),
|
||||||
|
ierrors.WithPublicMessage("Forbidden"),
|
||||||
|
ierrors.WithShouldReport(false),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e InvalidSecretError) Error() string { return "Invalid secret" }
|
@@ -1,4 +1,4 @@
|
|||||||
package router
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
@@ -59,6 +59,6 @@ func LogResponse(reqID string, r *http.Request, status int, err *ierrors.Error,
|
|||||||
|
|
||||||
log.WithFields(fields).Logf(
|
log.WithFields(fields).Logf(
|
||||||
level,
|
level,
|
||||||
"Completed in %s %s", ctxTime(r.Context()), r.RequestURI,
|
"Completed in %s %s", requestStartedAt(r.Context()), r.RequestURI,
|
||||||
)
|
)
|
||||||
}
|
}
|
145
server/middlewares.go
Normal file
145
server/middlewares.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/errorreport"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
categoryTimeout = "timeout"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithMetrics wraps RouteHandler with metrics handling.
|
||||||
|
func (r *Router) WithMetrics(h RouteHandler) RouteHandler {
|
||||||
|
if !metrics.Enabled() {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
ctx, metricsCancel, rw := metrics.StartRequest(req.Context(), rw, req)
|
||||||
|
defer metricsCancel()
|
||||||
|
|
||||||
|
return h(reqID, rw, req.WithContext(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCORS wraps RouteHandler with CORS handling
|
||||||
|
func (r *Router) WithCORS(h RouteHandler) RouteHandler {
|
||||||
|
if len(r.config.CORSAllowOrigin) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
rw.Header().Set(httpheaders.AccessControlAllowOrigin, r.config.CORSAllowOrigin)
|
||||||
|
rw.Header().Set(httpheaders.AccessControlAllowMethods, "GET, OPTIONS")
|
||||||
|
|
||||||
|
return h(reqID, rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSecret wraps RouteHandler with secret handling
|
||||||
|
func (r *Router) WithSecret(h RouteHandler) RouteHandler {
|
||||||
|
if len(r.config.Secret) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
authHeader := fmt.Appendf(nil, "Bearer %s", r.config.Secret)
|
||||||
|
|
||||||
|
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(req.Header.Get(httpheaders.Authorization)), authHeader) == 1 {
|
||||||
|
return h(reqID, rw, req)
|
||||||
|
} else {
|
||||||
|
return newInvalidSecretError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPanic recovers panic and converts it to normal error
|
||||||
|
func (r *Router) WithPanic(h RouteHandler) RouteHandler {
|
||||||
|
return func(reqID string, rw http.ResponseWriter, r *http.Request) (retErr error) {
|
||||||
|
defer func() {
|
||||||
|
// try to recover from panic
|
||||||
|
rerr := recover()
|
||||||
|
if rerr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// abort handler is an exception of net/http, we should simply repanic it.
|
||||||
|
// it will supress the stack trace
|
||||||
|
if rerr == http.ErrAbortHandler {
|
||||||
|
panic(rerr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// let's recover error value from panic if it has panicked with error
|
||||||
|
err, ok := rerr.(error)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("panic: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retErr = err
|
||||||
|
}()
|
||||||
|
|
||||||
|
return h(reqID, rw, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReportError handles error reporting.
|
||||||
|
// It should be placed after `WithMetrics`, but before `WithPanic`.
|
||||||
|
func (r *Router) WithReportError(h RouteHandler) RouteHandler {
|
||||||
|
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
// Open the error context
|
||||||
|
ctx := errorreport.StartRequest(req)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
errorreport.SetMetadata(req, "Request ID", reqID)
|
||||||
|
|
||||||
|
// Call the underlying handler passing the context downwards
|
||||||
|
err := h(reqID, rw, req)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap a resulting error into ierrors.Error
|
||||||
|
ierr := ierrors.Wrap(err, 0)
|
||||||
|
|
||||||
|
// Get the error category
|
||||||
|
errCat := ierr.Category()
|
||||||
|
|
||||||
|
// Exception: any context.DeadlineExceeded error is timeout
|
||||||
|
if errors.Is(ierr, context.DeadlineExceeded) {
|
||||||
|
errCat = categoryTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// We do not need to send any canceled context
|
||||||
|
if !errors.Is(ierr, context.Canceled) {
|
||||||
|
metrics.SendError(ctx, errCat, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Report error to error collectors
|
||||||
|
if ierr.ShouldReport() {
|
||||||
|
errorreport.Report(ierr, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log response and format the error output
|
||||||
|
LogResponse(reqID, req, ierr.StatusCode(), ierr)
|
||||||
|
|
||||||
|
// Error message: either is public message or full development error
|
||||||
|
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||||
|
rw.WriteHeader(ierr.StatusCode())
|
||||||
|
|
||||||
|
if r.config.DevelopmentErrorsMode {
|
||||||
|
rw.Write([]byte(ierr.Error()))
|
||||||
|
} else {
|
||||||
|
rw.Write([]byte(ierr.PublicMessage()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
209
server/router.go
Normal file
209
server/router.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
nanoid "github.com/matoous/go-nanoid/v2"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// defaultServerName is the default name of the server
|
||||||
|
defaultServerName = "imgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// requestIDRe is a regular expression for validating request IDs
|
||||||
|
requestIDRe = regexp.MustCompile(`^[A-Za-z0-9_\-]+$`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// RouteHandler is a function that handles HTTP requests.
|
||||||
|
type RouteHandler func(string, http.ResponseWriter, *http.Request) error
|
||||||
|
|
||||||
|
// Middleware is a function that wraps a RouteHandler with additional functionality.
|
||||||
|
type Middleware func(next RouteHandler) RouteHandler
|
||||||
|
|
||||||
|
// route represents a single route in the router.
|
||||||
|
type route struct {
|
||||||
|
method string // method is the HTTP method for a route
|
||||||
|
path string // path represents a route path
|
||||||
|
exact bool // exact means that path must match exactly, otherwise any prefixed matches
|
||||||
|
handler RouteHandler // handler is the function that handles the route
|
||||||
|
silent bool // Silent route (no logs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router is responsible for routing HTTP requests
|
||||||
|
type Router struct {
|
||||||
|
// config represents the server configuration
|
||||||
|
config *Config
|
||||||
|
|
||||||
|
// routes is the collection of all routes
|
||||||
|
routes []*route
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter creates a new Router instance
|
||||||
|
func NewRouter(config *Config) *Router {
|
||||||
|
return &Router{config: config}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds an abitary route to the router
|
||||||
|
func (r *Router) add(method, prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||||
|
for _, m := range middlewares {
|
||||||
|
handler = m(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
route := &route{method: method, path: r.config.PathPrefix + prefix, handler: handler, exact: exact}
|
||||||
|
|
||||||
|
r.routes = append(
|
||||||
|
r.routes,
|
||||||
|
route,
|
||||||
|
)
|
||||||
|
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET adds GET route
|
||||||
|
func (r *Router) GET(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||||
|
return r.add(http.MethodGet, prefix, exact, handler, middlewares...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPTIONS adds OPTIONS route
|
||||||
|
func (r *Router) OPTIONS(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||||
|
return r.add(http.MethodOptions, prefix, exact, handler, middlewares...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HEAD adds HEAD route
|
||||||
|
func (r *Router) HEAD(prefix string, exact bool, handler RouteHandler, middlewares ...Middleware) *route {
|
||||||
|
return r.add(http.MethodHead, prefix, exact, handler, middlewares...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP serves routes
|
||||||
|
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// Attach timer to the context
|
||||||
|
req, timeoutCancel := startRequestTimer(req)
|
||||||
|
defer timeoutCancel()
|
||||||
|
|
||||||
|
// Create the response writer which times out on write
|
||||||
|
rw = newTimeoutResponse(rw, r.config.WriteResponseTimeout)
|
||||||
|
|
||||||
|
// Get/create request ID
|
||||||
|
reqID := r.getRequestID(req)
|
||||||
|
|
||||||
|
// Replace request IP from headers
|
||||||
|
r.replaceRemoteAddr(req)
|
||||||
|
|
||||||
|
rw.Header().Set(httpheaders.Server, defaultServerName)
|
||||||
|
rw.Header().Set(httpheaders.XRequestID, reqID)
|
||||||
|
|
||||||
|
for _, rr := range r.routes {
|
||||||
|
if rr.isMatch(req) {
|
||||||
|
if !rr.silent {
|
||||||
|
LogRequest(reqID, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr.handler(reqID, rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Means that we have not found matching route
|
||||||
|
LogRequest(reqID, req)
|
||||||
|
LogResponse(reqID, req, http.StatusNotFound, newRouteNotDefinedError(req.URL.Path))
|
||||||
|
r.NotFoundHandler(reqID, rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotFoundHandler is default 404 handler
|
||||||
|
func (r *Router) NotFoundHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||||
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
|
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OkHandler is a default 200 OK handler
|
||||||
|
func (r *Router) OkHandler(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
rw.Header().Set(httpheaders.ContentType, "text/plain")
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
rw.Write([]byte{' '}) // Write a single byte to make AWS Lambda happy
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRequestID tries to read request id from headers or from lambda
|
||||||
|
// context or generates a new one if nothing found.
|
||||||
|
func (r *Router) getRequestID(req *http.Request) string {
|
||||||
|
// Get request ID from headers (if any)
|
||||||
|
reqID := req.Header.Get(httpheaders.XRequestID)
|
||||||
|
|
||||||
|
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
||||||
|
lambdaContextVal := req.Header.Get(httpheaders.XAmznRequestContextHeader)
|
||||||
|
|
||||||
|
if len(lambdaContextVal) > 0 {
|
||||||
|
var lambdaContext struct {
|
||||||
|
RequestID string `json:"requestId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte(lambdaContextVal), &lambdaContext)
|
||||||
|
if err == nil && len(lambdaContext.RequestID) > 0 {
|
||||||
|
reqID = lambdaContext.RequestID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(reqID) == 0 || !requestIDRe.MatchString(reqID) {
|
||||||
|
reqID, _ = nanoid.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
return reqID
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceRemoteAddr rewrites the req.RemoteAddr property from request headers
|
||||||
|
func (r *Router) replaceRemoteAddr(req *http.Request) {
|
||||||
|
cfConnectingIP := req.Header.Get(httpheaders.CFConnectingIP)
|
||||||
|
xForwardedFor := req.Header.Get(httpheaders.XForwardedFor)
|
||||||
|
xRealIP := req.Header.Get(httpheaders.XRealIP)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(cfConnectingIP) > 0:
|
||||||
|
replaceRemoteAddr(req, cfConnectingIP)
|
||||||
|
case len(xForwardedFor) > 0:
|
||||||
|
if index := strings.Index(xForwardedFor, ","); index > 0 {
|
||||||
|
xForwardedFor = xForwardedFor[:index]
|
||||||
|
}
|
||||||
|
replaceRemoteAddr(req, xForwardedFor)
|
||||||
|
case len(xRealIP) > 0:
|
||||||
|
replaceRemoteAddr(req, xRealIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceRemoteAddr sets the req.RemoteAddr for request
|
||||||
|
func replaceRemoteAddr(req *http.Request, ip string) {
|
||||||
|
_, port, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
port = "80"
|
||||||
|
}
|
||||||
|
|
||||||
|
req.RemoteAddr = net.JoinHostPort(strings.TrimSpace(ip), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMatch checks that a request matches route
|
||||||
|
func (r *route) isMatch(req *http.Request) bool {
|
||||||
|
methodMatches := r.method == req.Method
|
||||||
|
notExactPathMathes := !r.exact && strings.HasPrefix(req.URL.Path, r.path)
|
||||||
|
exactPathMatches := r.exact && req.URL.Path == r.path
|
||||||
|
|
||||||
|
return methodMatches && (notExactPathMathes || exactPathMatches)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Silent sets Silent flag which supresses logs to true. We do not need to log
|
||||||
|
// requests like /health of /favicon.ico
|
||||||
|
func (r *route) Silent() *route {
|
||||||
|
r.silent = true
|
||||||
|
return r
|
||||||
|
}
|
296
server/router_test.go
Normal file
296
server/router_test.go
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RouterTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
router *Router
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RouterTestSuite) SetupTest() {
|
||||||
|
c := NewConfigFromEnv()
|
||||||
|
c.PathPrefix = "/api"
|
||||||
|
s.router = NewRouter(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(RouterTestSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPMethods tests route methods registration and HTTP requests
|
||||||
|
func (s *RouterTestSuite) TestHTTPMethods() {
|
||||||
|
var capturedMethod string
|
||||||
|
var capturedPath string
|
||||||
|
|
||||||
|
getHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
capturedMethod = req.Method
|
||||||
|
capturedPath = req.URL.Path
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
rw.Write([]byte("GET response"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
optionsHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
capturedMethod = req.Method
|
||||||
|
capturedPath = req.URL.Path
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
rw.Write([]byte("OPTIONS response"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
headHandler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
capturedMethod = req.Method
|
||||||
|
capturedPath = req.URL.Path
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register routes with different configurations
|
||||||
|
s.router.GET("/get-test", true, getHandler) // exact match
|
||||||
|
s.router.OPTIONS("/options-test", false, optionsHandler) // prefix match
|
||||||
|
s.router.HEAD("/head-test", true, headHandler) // exact match
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requestMethod string
|
||||||
|
requestPath string
|
||||||
|
expectedBody string
|
||||||
|
expectedPath string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "GET",
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/api/get-test",
|
||||||
|
expectedBody: "GET response",
|
||||||
|
expectedPath: "/api/get-test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OPTIONS",
|
||||||
|
requestMethod: http.MethodOptions,
|
||||||
|
requestPath: "/api/options-test",
|
||||||
|
expectedBody: "OPTIONS response",
|
||||||
|
expectedPath: "/api/options-test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OPTIONSPrefixed",
|
||||||
|
requestMethod: http.MethodOptions,
|
||||||
|
requestPath: "/api/options-test/sub",
|
||||||
|
expectedBody: "OPTIONS response",
|
||||||
|
expectedPath: "/api/options-test/sub",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HEAD",
|
||||||
|
requestMethod: http.MethodHead,
|
||||||
|
requestPath: "/api/head-test",
|
||||||
|
expectedBody: "",
|
||||||
|
expectedPath: "/api/head-test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
req := httptest.NewRequest(tt.requestMethod, tt.requestPath, nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
s.Require().Equal(tt.expectedBody, rw.Body.String())
|
||||||
|
s.Require().Equal(tt.requestMethod, capturedMethod)
|
||||||
|
s.Require().Equal(tt.expectedPath, capturedPath)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMiddlewareOrder checks middleware ordering and functionality
|
||||||
|
func (s *RouterTestSuite) TestMiddlewareOrder() {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
middleware1 := func(next RouteHandler) RouteHandler {
|
||||||
|
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
order = append(order, "middleware1")
|
||||||
|
return next(reqID, rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
middleware2 := func(next RouteHandler) RouteHandler {
|
||||||
|
return func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
order = append(order, "middleware2")
|
||||||
|
return next(reqID, rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
order = append(order, "handler")
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.router.GET("/test", true, handler, middleware2, middleware1)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
// Middleware should execute in the order they are passed (first added first)
|
||||||
|
s.Require().Equal([]string{"middleware1", "middleware2", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServeHTTP tests ServeHTTP method
|
||||||
|
func (s *RouterTestSuite) TestServeHTTP() {
|
||||||
|
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
rw.Header().Set("Custom-Header", "test-value")
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
rw.Write([]byte("success"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.router.GET("/test", true, handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
s.Require().Equal(200, rw.Code)
|
||||||
|
s.Require().Equal("success", rw.Body.String())
|
||||||
|
s.Require().Equal("test-value", rw.Header().Get("Custom-Header"))
|
||||||
|
s.Require().Equal(defaultServerName, rw.Header().Get(httpheaders.Server))
|
||||||
|
s.Require().NotEmpty(rw.Header().Get(httpheaders.XRequestID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequestID checks request ID generation and validation
|
||||||
|
func (s *RouterTestSuite) TestRequestID() {
|
||||||
|
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.router.GET("/test", true, handler)
|
||||||
|
|
||||||
|
// Test request ID passthrough (if present)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
req.Header.Set(httpheaders.XRequestID, "valid-id-123")
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
s.Require().Equal("valid-id-123", rw.Header().Get(httpheaders.XRequestID))
|
||||||
|
|
||||||
|
// Test invalid request ID (should generate a new one)
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
req.Header.Set(httpheaders.XRequestID, "invalid id with spaces!")
|
||||||
|
rw = httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
generatedID := rw.Header().Get(httpheaders.XRequestID)
|
||||||
|
s.Require().NotEqual("invalid id with spaces!", generatedID)
|
||||||
|
s.Require().NotEmpty(generatedID)
|
||||||
|
|
||||||
|
// Test no request ID (should generate a new one)
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
rw = httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
generatedID = rw.Header().Get(httpheaders.XRequestID)
|
||||||
|
s.Require().NotEmpty(generatedID)
|
||||||
|
s.Require().Regexp(`^[A-Za-z0-9_\-]+$`, generatedID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLambdaRequestIDExtraction checks AWS lambda request id extraction
|
||||||
|
func (s *RouterTestSuite) TestLambdaRequestIDExtraction() {
|
||||||
|
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.router.GET("/test", true, handler)
|
||||||
|
|
||||||
|
// Test with valid Lambda context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
req.Header.Set(httpheaders.XAmznRequestContextHeader, `{"requestId":"lambda-req-123"}`)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
s.Require().Equal("lambda-req-123", rw.Header().Get(httpheaders.XRequestID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test IP address handling
|
||||||
|
func (s *RouterTestSuite) TestReplaceIP() {
|
||||||
|
var capturedRemoteAddr string
|
||||||
|
handler := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
capturedRemoteAddr = req.RemoteAddr
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.router.GET("/test", true, handler)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
originalAddr string
|
||||||
|
headers map[string]string
|
||||||
|
expectedAddr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "CFConnectingIP",
|
||||||
|
originalAddr: "original:8080",
|
||||||
|
headers: map[string]string{
|
||||||
|
httpheaders.CFConnectingIP: "1.2.3.4",
|
||||||
|
},
|
||||||
|
expectedAddr: "1.2.3.4:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "XForwardedForMulti",
|
||||||
|
originalAddr: "original:8080",
|
||||||
|
headers: map[string]string{
|
||||||
|
httpheaders.XForwardedFor: "5.6.7.8, 9.10.11.12",
|
||||||
|
},
|
||||||
|
expectedAddr: "5.6.7.8:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "XForwardedForSingle",
|
||||||
|
originalAddr: "original:8080",
|
||||||
|
headers: map[string]string{
|
||||||
|
httpheaders.XForwardedFor: "13.14.15.16",
|
||||||
|
},
|
||||||
|
expectedAddr: "13.14.15.16:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "XRealIP",
|
||||||
|
originalAddr: "original:8080",
|
||||||
|
headers: map[string]string{
|
||||||
|
httpheaders.XRealIP: "17.18.19.20",
|
||||||
|
},
|
||||||
|
expectedAddr: "17.18.19.20:8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
req.RemoteAddr = tt.originalAddr
|
||||||
|
|
||||||
|
for header, value := range tt.headers {
|
||||||
|
req.Header.Set(header, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.router.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
s.Require().Equal(tt.expectedAddr, capturedRemoteAddr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
82
server/server.go
Normal file
82
server/server.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
golog "log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/netutil"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/reuseport"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxHeaderBytes represents max bytes in request header
|
||||||
|
maxHeaderBytes = 1 << 20
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server represents the HTTP server wrapper struct
|
||||||
|
type Server struct {
|
||||||
|
router *Router
|
||||||
|
server *http.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the http server. cancel is called in case server failed to start, but it happened
|
||||||
|
// asynchronously. It should cancel the upstream context.
|
||||||
|
func Start(cancel context.CancelFunc, router *Router) (*Server, error) {
|
||||||
|
l, err := reuseport.Listen(router.config.Network, router.config.Bind, router.config.SocketReusePort)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
return nil, fmt.Errorf("can't start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if router.config.MaxClients > 0 {
|
||||||
|
l = netutil.LimitListener(l, router.config.MaxClients)
|
||||||
|
}
|
||||||
|
|
||||||
|
errLogger := golog.New(
|
||||||
|
log.WithField("source", "http_server").WriterLevel(log.ErrorLevel),
|
||||||
|
"", 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: router,
|
||||||
|
ReadTimeout: router.config.ReadRequestTimeout,
|
||||||
|
MaxHeaderBytes: maxHeaderBytes,
|
||||||
|
ErrorLog: errLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.KeepAliveTimeout > 0 {
|
||||||
|
srv.IdleTimeout = router.config.KeepAliveTimeout
|
||||||
|
} else {
|
||||||
|
srv.SetKeepAlivesEnabled(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Infof("Starting server at %s", router.config.Bind)
|
||||||
|
|
||||||
|
if err := srv.Serve(l); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
router: router,
|
||||||
|
server: srv,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the server
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) {
|
||||||
|
log.Info("Shutting down the server...")
|
||||||
|
|
||||||
|
ctx, close := context.WithTimeout(ctx, s.router.config.GracefulTimeout)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
s.server.Shutdown(ctx)
|
||||||
|
}
|
268
server/server_test.go
Normal file
268
server/server_test.go
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
config *Config
|
||||||
|
blankRouter *Router
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) SetupTest() {
|
||||||
|
config.Reset()
|
||||||
|
s.config = NewConfigFromEnv()
|
||||||
|
s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
|
||||||
|
s.blankRouter = NewRouter(s.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
|
||||||
|
ctx, cancel := context.WithCancel(s.T().Context())
|
||||||
|
|
||||||
|
// Track if cancel was called using atomic
|
||||||
|
var cancelCalled atomic.Bool
|
||||||
|
cancelWrapper := func() {
|
||||||
|
cancel()
|
||||||
|
cancelCalled.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidConfig := &Config{
|
||||||
|
Network: "tcp",
|
||||||
|
Bind: "invalid-address", // Invalid address
|
||||||
|
}
|
||||||
|
|
||||||
|
r := NewRouter(invalidConfig)
|
||||||
|
|
||||||
|
server, err := Start(cancelWrapper, r)
|
||||||
|
|
||||||
|
s.Require().Error(err)
|
||||||
|
s.Nil(server)
|
||||||
|
s.Contains(err.Error(), "can't start server")
|
||||||
|
|
||||||
|
// Check if cancel was called using Eventually
|
||||||
|
s.Require().Eventually(cancelCalled.Load, 100*time.Millisecond, 10*time.Millisecond)
|
||||||
|
|
||||||
|
// Also verify the context was cancelled
|
||||||
|
s.Require().Eventually(func() bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestShutdown() {
|
||||||
|
_, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
server, err := Start(cancel, s.blankRouter)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.NotNil(server)
|
||||||
|
|
||||||
|
// Test graceful shutdown
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(s.T().Context(), 10*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
// Should not panic or hang
|
||||||
|
s.NotPanics(func() {
|
||||||
|
server.Shutdown(shutdownCtx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestWithCORS() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
corsAllowOrigin string
|
||||||
|
expectedOrigin string
|
||||||
|
expectedMethods string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "WithCORSOrigin",
|
||||||
|
corsAllowOrigin: "https://example.com",
|
||||||
|
expectedOrigin: "https://example.com",
|
||||||
|
expectedMethods: "GET, OPTIONS",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NoCORSOrigin",
|
||||||
|
corsAllowOrigin: "",
|
||||||
|
expectedOrigin: "",
|
||||||
|
expectedMethods: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
config := &Config{
|
||||||
|
CORSAllowOrigin: tt.corsAllowOrigin,
|
||||||
|
}
|
||||||
|
router := NewRouter(config)
|
||||||
|
|
||||||
|
wrappedHandler := router.WithCORS(s.mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrappedHandler("test-req-id", rw, req)
|
||||||
|
|
||||||
|
s.Equal(tt.expectedOrigin, rw.Header().Get(httpheaders.AccessControlAllowOrigin))
|
||||||
|
s.Equal(tt.expectedMethods, rw.Header().Get(httpheaders.AccessControlAllowMethods))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestWithSecret() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
secret string
|
||||||
|
authHeader string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ValidSecret",
|
||||||
|
secret: "test-secret",
|
||||||
|
authHeader: "Bearer test-secret",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "InvalidSecret",
|
||||||
|
secret: "foo-secret",
|
||||||
|
authHeader: "Bearer wrong-secret",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NoSecretConfigured",
|
||||||
|
secret: "",
|
||||||
|
authHeader: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
config := &Config{
|
||||||
|
Secret: tt.secret,
|
||||||
|
}
|
||||||
|
router := NewRouter(config)
|
||||||
|
|
||||||
|
wrappedHandler := router.WithSecret(s.mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
if tt.authHeader != "" {
|
||||||
|
req.Header.Set(httpheaders.Authorization, tt.authHeader)
|
||||||
|
}
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err := wrappedHandler("test-req-id", rw, req)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
s.Require().Error(err)
|
||||||
|
} else {
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestIntoSuccess() {
|
||||||
|
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedHandler := s.blankRouter.WithReportError(mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrappedHandler("test-req-id", rw, req)
|
||||||
|
|
||||||
|
s.Equal(http.StatusOK, rw.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestIntoWithError() {
|
||||||
|
testError := errors.New("test error")
|
||||||
|
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
return testError
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedHandler := s.blankRouter.WithReportError(mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrappedHandler("test-req-id", rw, req)
|
||||||
|
|
||||||
|
s.Equal(http.StatusInternalServerError, rw.Code)
|
||||||
|
s.Equal("text/plain", rw.Header().Get(httpheaders.ContentType))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestIntoPanicWithError() {
|
||||||
|
testError := errors.New("panic error")
|
||||||
|
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
panic(testError)
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.NotPanics(func() {
|
||||||
|
err := wrappedHandler("test-req-id", rw, req)
|
||||||
|
s.Require().Error(err, "panic error")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Equal(http.StatusOK, rw.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
|
||||||
|
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Should re-panic with ErrAbortHandler
|
||||||
|
s.Panics(func() {
|
||||||
|
wrappedHandler("test-req-id", rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
|
||||||
|
mockHandler := func(reqID string, rw http.ResponseWriter, r *http.Request) error {
|
||||||
|
panic("string panic")
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedHandler := s.blankRouter.WithPanic(mockHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Should re-panic with non-error panics
|
||||||
|
s.NotPanics(func() {
|
||||||
|
err := wrappedHandler("test-req-id", rw, req)
|
||||||
|
s.Require().Error(err, "string panic")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerTestSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ServerTestSuite))
|
||||||
|
}
|
52
server/timeout_response.go
Normal file
52
server/timeout_response.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// timeoutResponse manages response writer with timeout. It has
|
||||||
|
// timeout on all write methods.
|
||||||
|
type timeoutResponse struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
controller *http.ResponseController
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTimeoutResponse creates a new timeoutResponse
|
||||||
|
func newTimeoutResponse(rw http.ResponseWriter, timeout time.Duration) http.ResponseWriter {
|
||||||
|
return &timeoutResponse{
|
||||||
|
ResponseWriter: rw,
|
||||||
|
controller: http.NewResponseController(rw),
|
||||||
|
timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements http.ResponseWriter.Write
|
||||||
|
func (rw *timeoutResponse) Write(b []byte) (int, error) {
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
rw.withWriteDeadline(func() {
|
||||||
|
n, err = rw.ResponseWriter.Write(b)
|
||||||
|
})
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header returns current HTTP headers
|
||||||
|
func (rw *timeoutResponse) Header() http.Header {
|
||||||
|
return rw.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
// withWriteDeadline executes a Write* function with a deadline
|
||||||
|
func (rw *timeoutResponse) withWriteDeadline(f func()) {
|
||||||
|
deadline := time.Now().Add(rw.timeout)
|
||||||
|
|
||||||
|
// Set write deadline
|
||||||
|
rw.controller.SetWriteDeadline(deadline)
|
||||||
|
|
||||||
|
// Reset write deadline after method has finished
|
||||||
|
defer rw.controller.SetWriteDeadline(time.Time{})
|
||||||
|
f()
|
||||||
|
}
|
@@ -1,4 +1,6 @@
|
|||||||
package router
|
// timer.go contains methods for storing, retrieving and checking
|
||||||
|
// timer in a request context.
|
||||||
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,8 +11,10 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/ierrors"
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// timerSinceCtxKey represents a context key for start time.
|
||||||
type timerSinceCtxKey struct{}
|
type timerSinceCtxKey struct{}
|
||||||
|
|
||||||
|
// startRequestTimer starts a new request timer.
|
||||||
func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
|
func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ctx = context.WithValue(ctx, timerSinceCtxKey{}, time.Now())
|
ctx = context.WithValue(ctx, timerSinceCtxKey{}, time.Now())
|
||||||
@@ -18,17 +22,20 @@ func startRequestTimer(r *http.Request) (*http.Request, context.CancelFunc) {
|
|||||||
return r.WithContext(ctx), cancel
|
return r.WithContext(ctx), cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
func ctxTime(ctx context.Context) time.Duration {
|
// requestStartedAt returns the duration since the timer started in the context.
|
||||||
|
func requestStartedAt(ctx context.Context) time.Duration {
|
||||||
if t, ok := ctx.Value(timerSinceCtxKey{}).(time.Time); ok {
|
if t, ok := ctx.Value(timerSinceCtxKey{}).(time.Time); ok {
|
||||||
return time.Since(t)
|
return time.Since(t)
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckTimeout checks if the request context has timed out or cancelled and returns
|
||||||
|
// wrapped error.
|
||||||
func CheckTimeout(ctx context.Context) error {
|
func CheckTimeout(ctx context.Context) error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
d := ctxTime(ctx)
|
d := requestStartedAt(ctx)
|
||||||
|
|
||||||
err := ctx.Err()
|
err := ctx.Err()
|
||||||
switch err {
|
switch err {
|
||||||
@@ -37,7 +44,7 @@ func CheckTimeout(ctx context.Context) error {
|
|||||||
case context.DeadlineExceeded:
|
case context.DeadlineExceeded:
|
||||||
return newRequestTimeoutError(d)
|
return newRequestTimeoutError(d)
|
||||||
default:
|
default:
|
||||||
return ierrors.Wrap(err, 0)
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryTimeout))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
67
server/timer_test.go
Normal file
67
server/timer_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCheckTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func() context.Context
|
||||||
|
fail bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "WithoutTimeout",
|
||||||
|
setup: context.Background,
|
||||||
|
fail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ActiveTimerContext",
|
||||||
|
setup: func() context.Context {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
newReq, _ := startRequestTimer(req)
|
||||||
|
return newReq.Context()
|
||||||
|
},
|
||||||
|
fail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CancelledContext",
|
||||||
|
setup: func() context.Context {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
newReq, cancel := startRequestTimer(req)
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
return newReq.Context()
|
||||||
|
},
|
||||||
|
fail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DeadlineExceeded",
|
||||||
|
setup: func() context.Context {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
|
||||||
|
defer cancel()
|
||||||
|
time.Sleep(time.Millisecond * 10) // Ensure timeout
|
||||||
|
return ctx
|
||||||
|
},
|
||||||
|
fail: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := tt.setup()
|
||||||
|
err := CheckTimeout(ctx)
|
||||||
|
|
||||||
|
if tt.fail {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
21
stream.go
21
stream.go
@@ -12,11 +12,12 @@ import (
|
|||||||
"github.com/imgproxy/imgproxy/v3/config"
|
"github.com/imgproxy/imgproxy/v3/config"
|
||||||
"github.com/imgproxy/imgproxy/v3/cookies"
|
"github.com/imgproxy/imgproxy/v3/cookies"
|
||||||
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
"github.com/imgproxy/imgproxy/v3/httpheaders"
|
||||||
|
"github.com/imgproxy/imgproxy/v3/ierrors"
|
||||||
"github.com/imgproxy/imgproxy/v3/imagedata"
|
"github.com/imgproxy/imgproxy/v3/imagedata"
|
||||||
"github.com/imgproxy/imgproxy/v3/metrics"
|
"github.com/imgproxy/imgproxy/v3/metrics"
|
||||||
"github.com/imgproxy/imgproxy/v3/metrics/stats"
|
"github.com/imgproxy/imgproxy/v3/metrics/stats"
|
||||||
"github.com/imgproxy/imgproxy/v3/options"
|
"github.com/imgproxy/imgproxy/v3/options"
|
||||||
"github.com/imgproxy/imgproxy/v3/router"
|
"github.com/imgproxy/imgproxy/v3/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -44,7 +45,7 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) {
|
func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, imageURL string) error {
|
||||||
stats.IncImagesInProgress()
|
stats.IncImagesInProgress()
|
||||||
defer stats.DecImagesInProgress()
|
defer stats.DecImagesInProgress()
|
||||||
|
|
||||||
@@ -65,18 +66,24 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
|
|||||||
|
|
||||||
if config.CookiePassthrough {
|
if config.CookiePassthrough {
|
||||||
cookieJar, err = cookies.JarFromRequest(r)
|
cookieJar, err = cookies.JarFromRequest(r)
|
||||||
checkErr(ctx, "streaming", err)
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
|
req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
|
||||||
defer req.Cancel()
|
defer req.Cancel()
|
||||||
checkErr(ctx, "streaming", err)
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
|
||||||
|
}
|
||||||
|
|
||||||
res, err := req.Send()
|
res, err := req.Send()
|
||||||
if res != nil {
|
if res != nil {
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
}
|
}
|
||||||
checkErr(ctx, "streaming", err)
|
if err != nil {
|
||||||
|
return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryStreaming))
|
||||||
|
}
|
||||||
|
|
||||||
for _, k := range streamRespHeaders {
|
for _, k := range streamRespHeaders {
|
||||||
vv := res.Header.Values(k)
|
vv := res.Header.Values(k)
|
||||||
@@ -116,7 +123,7 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
|
|||||||
copyerr = nil
|
copyerr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
router.LogResponse(
|
server.LogResponse(
|
||||||
reqID, r, res.StatusCode, nil,
|
reqID, r, res.StatusCode, nil,
|
||||||
log.Fields{
|
log.Fields{
|
||||||
"image_url": imageURL,
|
"image_url": imageURL,
|
||||||
@@ -127,4 +134,6 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
|
|||||||
if copyerr != nil {
|
if copyerr != nil {
|
||||||
panic(http.ErrAbortHandler)
|
panic(http.ErrAbortHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user