use "clean" HTTP transport; Use context for downloading timeout control

This commit is contained in:
DarthSim
2023-03-21 20:58:16 +03:00
parent dde81b49f7
commit 24f4d43a0f
7 changed files with 104 additions and 50 deletions

View File

@@ -2,9 +2,11 @@ package imagedata
import ( import (
"compress/gzip" "compress/gzip"
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"time" "time"
@@ -55,16 +57,25 @@ func (e *ErrorNotModified) Error() string {
} }
func initDownloading() error { func initDownloading() error {
transport := http.DefaultTransport.(*http.Transport).Clone() transport := &http.Transport{
transport.DisableCompression = true Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: config.Concurrency + 1,
IdleConnTimeout: time.Duration(config.ClientKeepAliveTimeout) * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: true,
DisableCompression: true,
}
if config.ClientKeepAliveTimeout > 0 { if config.ClientKeepAliveTimeout <= 0 {
transport.MaxIdleConns = config.Concurrency transport.MaxIdleConnsPerHost = -1
transport.MaxIdleConnsPerHost = config.Concurrency transport.DisableKeepAlives = true
transport.IdleConnTimeout = time.Duration(config.ClientKeepAliveTimeout) * time.Second
} else {
transport.MaxIdleConns = 0
transport.MaxIdleConnsPerHost = 0
} }
if config.IgnoreSslVerification { if config.IgnoreSslVerification {
@@ -113,7 +124,6 @@ func initDownloading() error {
} }
downloadClient = &http.Client{ downloadClient = &http.Client{
Timeout: time.Duration(config.DownloadTimeout) * time.Second,
Transport: transport, Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
redirects := len(via) redirects := len(via)
@@ -139,14 +149,18 @@ func headersToStore(res *http.Response) map[string]string {
return m return m
} }
func BuildImageRequest(imageURL string, header http.Header, jar *cookiejar.Jar) (*http.Request, error) { func BuildImageRequest(ctx context.Context, imageURL string, header http.Header, jar *cookiejar.Jar) (*http.Request, context.CancelFunc, error) {
req, err := http.NewRequest("GET", imageURL, nil) reqCtx, reqCancel := context.WithTimeout(ctx, time.Duration(config.DownloadTimeout)*time.Second)
req, err := http.NewRequestWithContext(reqCtx, "GET", imageURL, nil)
if err != nil { if err != nil {
return nil, ierrors.New(404, err.Error(), msgSourceImageIsUnreachable) reqCancel()
return nil, func() {}, ierrors.New(404, err.Error(), msgSourceImageIsUnreachable)
} }
if _, ok := enabledSchemes[req.URL.Scheme]; !ok { if _, ok := enabledSchemes[req.URL.Scheme]; !ok {
return nil, ierrors.New( reqCancel()
return nil, func() {}, ierrors.New(
404, 404,
fmt.Sprintf("Unknown scheme: %s", req.URL.Scheme), fmt.Sprintf("Unknown scheme: %s", req.URL.Scheme),
msgSourceImageIsUnreachable, msgSourceImageIsUnreachable,
@@ -167,37 +181,41 @@ func BuildImageRequest(imageURL string, header http.Header, jar *cookiejar.Jar)
} }
} }
return req, nil return req, reqCancel, nil
} }
func SendRequest(req *http.Request) (*http.Response, error) { func SendRequest(req *http.Request) (*http.Response, error) {
res, err := downloadClient.Do(req) res, err := downloadClient.Do(req)
if err != nil { if err != nil {
return nil, ierrors.New(500, checkTimeoutErr(err).Error(), msgSourceImageIsUnreachable) return nil, wrapError(err)
} }
return res, nil return res, nil
} }
func requestImage(imageURL string, opts DownloadOptions) (*http.Response, error) { func requestImage(ctx context.Context, imageURL string, opts DownloadOptions) (*http.Response, context.CancelFunc, error) {
req, err := BuildImageRequest(imageURL, opts.Header, opts.CookieJar) req, reqCancel, err := BuildImageRequest(ctx, imageURL, opts.Header, opts.CookieJar)
if err != nil { if err != nil {
return nil, err reqCancel()
return nil, func() {}, err
} }
res, err := SendRequest(req) res, err := SendRequest(req)
if err != nil { if err != nil {
return nil, err reqCancel()
return nil, func() {}, err
} }
if res.StatusCode == http.StatusNotModified { if res.StatusCode == http.StatusNotModified {
res.Body.Close() res.Body.Close()
return nil, &ErrorNotModified{Message: "Not Modified", Headers: headersToStore(res)} reqCancel()
return nil, func() {}, &ErrorNotModified{Message: "Not Modified", Headers: headersToStore(res)}
} }
if res.StatusCode != 200 { if res.StatusCode != 200 {
body, _ := io.ReadAll(res.Body) body, _ := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
reqCancel()
status := 404 status := 404
if res.StatusCode >= 500 { if res.StatusCode >= 500 {
@@ -205,19 +223,21 @@ func requestImage(imageURL string, opts DownloadOptions) (*http.Response, error)
} }
msg := fmt.Sprintf("Status: %d; %s", res.StatusCode, string(body)) msg := fmt.Sprintf("Status: %d; %s", res.StatusCode, string(body))
return nil, ierrors.New(status, msg, msgSourceImageIsUnreachable) return nil, func() {}, ierrors.New(status, msg, msgSourceImageIsUnreachable)
} }
return res, nil return res, reqCancel, nil
} }
func download(imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) { func download(ctx context.Context, imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
// We use this for testing // We use this for testing
if len(redirectAllRequestsTo) > 0 { if len(redirectAllRequestsTo) > 0 {
imageURL = redirectAllRequestsTo imageURL = redirectAllRequestsTo
} }
res, err := requestImage(imageURL, opts) res, reqCancel, err := requestImage(ctx, imageURL, opts)
defer reqCancel()
if res != nil { if res != nil {
defer res.Body.Close() defer res.Body.Close()
} }

43
imagedata/error.go Normal file
View File

@@ -0,0 +1,43 @@
package imagedata
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/imgproxy/imgproxy/v3/ierrors"
)
type httpError interface {
Timeout() bool
}
func wrapError(err error) error {
isTimeout := false
if errors.Is(err, context.Canceled) {
return ierrors.New(
499,
fmt.Sprintf("The image request is cancelled: %s", err),
msgSourceImageIsUnreachable,
)
} else if errors.Is(err, context.DeadlineExceeded) {
isTimeout = true
} else if httpErr, ok := err.(httpError); ok {
isTimeout = httpErr.Timeout()
}
if !isTimeout {
return err
}
ierr := ierrors.New(
http.StatusGatewayTimeout,
fmt.Sprintf("The image request timed out: %s", err),
msgSourceImageIsUnreachable,
)
ierr.Unexpected = true
return ierr
}

View File

@@ -70,7 +70,7 @@ func loadWatermark() (err error) {
} }
if len(config.WatermarkURL) > 0 { if len(config.WatermarkURL) > 0 {
Watermark, err = Download(config.WatermarkURL, "watermark", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions()) Watermark, err = Download(context.Background(), config.WatermarkURL, "watermark", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions())
return return
} }
@@ -84,7 +84,7 @@ func loadFallbackImage() (err error) {
case len(config.FallbackImagePath) > 0: case len(config.FallbackImagePath) > 0:
FallbackImage, err = FromFile(config.FallbackImagePath, "fallback image", security.DefaultOptions()) FallbackImage, err = FromFile(config.FallbackImagePath, "fallback image", security.DefaultOptions())
case len(config.FallbackImageURL) > 0: case len(config.FallbackImageURL) > 0:
FallbackImage, err = Download(config.FallbackImageURL, "fallback image", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions()) FallbackImage, err = Download(context.Background(), config.FallbackImageURL, "fallback image", DownloadOptions{Header: nil, CookieJar: nil}, security.DefaultOptions())
default: default:
FallbackImage, err = nil, nil FallbackImage, err = nil, nil
} }
@@ -130,8 +130,8 @@ func FromFile(path, desc string, secopts security.Options) (*ImageData, error) {
return imgdata, nil return imgdata, nil
} }
func Download(imageURL, desc string, opts DownloadOptions, secopts security.Options) (*ImageData, error) { func Download(ctx context.Context, imageURL, desc string, opts DownloadOptions, secopts security.Options) (*ImageData, error) {
imgdata, err := download(imageURL, opts, secopts) imgdata, err := download(ctx, imageURL, opts, secopts)
if err != nil { if err != nil {
if nmErr, ok := err.(*ErrorNotModified); ok { if nmErr, ok := err.(*ErrorNotModified); ok {
nmErr.Message = fmt.Sprintf("Can't download %s: %s", desc, nmErr.Message) nmErr.Message = fmt.Sprintf("Can't download %s: %s", desc, nmErr.Message)

View File

@@ -42,13 +42,14 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options)
return nil, ErrSourceImageTypeNotSupported return nil, ErrSourceImageTypeNotSupported
} }
return nil, checkTimeoutErr(err) return nil, wrapError(err)
} }
if err = security.CheckDimensions(meta.Width(), meta.Height(), 1, secopts); err != nil { if err = security.CheckDimensions(meta.Width(), meta.Height(), 1, secopts); err != nil {
buf.Reset() buf.Reset()
cancel() cancel()
return nil, err
return nil, wrapError(err)
} }
if contentLength > buf.Cap() { if contentLength > buf.Cap() {
@@ -56,8 +57,10 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options)
} }
if err = br.Flush(); err != nil { if err = br.Flush(); err != nil {
buf.Reset()
cancel() cancel()
return nil, checkTimeoutErr(err)
return nil, wrapError(err)
} }
return &ImageData{ return &ImageData{

View File

@@ -1,14 +0,0 @@
package imagedata
import "errors"
type httpError interface {
Timeout() bool
}
func checkTimeoutErr(err error) error {
if httpErr, ok := err.(httpError); ok && httpErr.Timeout() {
return errors.New("The image request timed out")
}
return err
}

View File

@@ -303,7 +303,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) {
checkErr(ctx, "download", err) checkErr(ctx, "download", err)
} }
return imagedata.Download(imageURL, "source image", downloadOpts, po.SecurityOptions) return imagedata.Download(ctx, imageURL, "source image", downloadOpts, po.SecurityOptions)
}() }()
if err == nil { if err == nil {

View File

@@ -71,13 +71,15 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht
checkErr(ctx, "streaming", err) checkErr(ctx, "streaming", err)
} }
req, err := imagedata.BuildImageRequest(imageURL, imgRequestHeader, cookieJar) req, reqCancel, err := imagedata.BuildImageRequest(r.Context(), imageURL, imgRequestHeader, cookieJar)
defer reqCancel()
checkErr(ctx, "streaming", err) checkErr(ctx, "streaming", err)
res, err := imagedata.SendRequest(req) res, err := imagedata.SendRequest(req)
checkErr(ctx, "streaming", err) if res != nil {
defer res.Body.Close() defer res.Body.Close()
}
checkErr(ctx, "streaming", err)
for _, k := range streamRespHeaders { for _, k := range streamRespHeaders {
vv := res.Header.Values(k) vv := res.Header.Values(k)