diff --git a/router/router.go b/router/router.go index 37ddccf0..6f285c61 100644 --- a/router/router.go +++ b/router/router.go @@ -19,7 +19,6 @@ var ( ) type RouteHandler func(string, http.ResponseWriter, *http.Request) -type PanicHandler func(string, http.ResponseWriter, *http.Request, error) type route struct { Method string @@ -29,9 +28,8 @@ type route struct { } type Router struct { - prefix string - Routes []*route - PanicHandler PanicHandler + prefix string + Routes []*route } func (r *route) isMatch(req *http.Request) bool { @@ -95,16 +93,6 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { replaceRemoteAddr(req, ip) } - defer func() { - if rerr := recover(); rerr != nil { - if err, ok := rerr.(error); ok && r.PanicHandler != nil { - r.PanicHandler(reqID, rw, req, err) - } else { - panic(rerr) - } - } - }() - LogRequest(reqID, req) for _, rr := range r.Routes { diff --git a/server.go b/server.go index dcc775e3..4f8f0a62 100644 --- a/server.go +++ b/server.go @@ -26,12 +26,10 @@ var ( func buildRouter() *router.Router { r := router.New(config.PathPrefix) - r.PanicHandler = handlePanic - r.GET("/", handleLanding, true) r.GET("/health", handleHealth, true) r.GET("/favicon.ico", handleFavicon, true) - r.GET("/", withCORS(withSecret(handleProcessing)), false) + r.GET("/", withCORS(withPanicHandler(withSecret(handleProcessing))), false) r.HEAD("/", withCORS(handleHead), false) r.OPTIONS("/", withCORS(handleHead), false) @@ -104,21 +102,34 @@ func withSecret(h router.RouteHandler) router.RouteHandler { } } -func handlePanic(reqID string, rw http.ResponseWriter, r *http.Request, err error) { - ierr := ierrors.Wrap(err, 3) +func withPanicHandler(h router.RouteHandler) router.RouteHandler { + return func(reqID string, rw http.ResponseWriter, r *http.Request) { + defer func() { + if rerr := recover(); rerr != nil { + err, ok := rerr.(error) + if !ok { + panic(rerr) + } - if ierr.Unexpected { - errorreport.Report(err, r) - } + ierr := ierrors.Wrap(err, 3) - router.LogResponse(reqID, r, ierr.StatusCode, ierr) + if ierr.Unexpected { + errorreport.Report(err, r) + } - rw.WriteHeader(ierr.StatusCode) + router.LogResponse(reqID, r, ierr.StatusCode, ierr) - if config.DevelopmentErrorsMode { - rw.Write([]byte(ierr.Message)) - } else { - rw.Write([]byte(ierr.PublicMessage)) + rw.WriteHeader(ierr.StatusCode) + + if config.DevelopmentErrorsMode { + rw.Write([]byte(ierr.Message)) + } else { + rw.Write([]byte(ierr.PublicMessage)) + } + } + }() + + h(reqID, rw, r) } }