From dd3b430f8754a2d1ca89a97d08469023d1b3e61d Mon Sep 17 00:00:00 2001 From: Victor Sokolov Date: Fri, 25 Jul 2025 12:26:21 +0200 Subject: [PATCH] transport isolated, imagefetcher introduced (#1465) --- go.mod | 1 + go.sum | 2 + imagedata/download.go | 302 +++++--------------------- imagedata/read.go | 13 +- {imagedata => imagefetcher}/errors.go | 13 +- imagefetcher/fetcher.go | 86 ++++++++ imagefetcher/request.go | 204 +++++++++++++++++ processing_handler.go | 11 +- security/file_size.go | 38 ---- security/response_limit.go | 51 +++++ stream.go | 8 +- transport/azure/azure.go | 4 +- transport/gcs/gcs.go | 4 +- transport/generichttp/generic_http.go | 59 +++++ transport/s3/s3.go | 4 +- transport/swift/swift.go | 4 +- transport/transport.go | 135 ++++++++---- 17 files changed, 578 insertions(+), 361 deletions(-) rename {imagedata => imagefetcher}/errors.go (92%) create mode 100644 imagefetcher/fetcher.go create mode 100644 imagefetcher/request.go delete mode 100644 security/file_size.go create mode 100644 security/response_limit.go create mode 100644 transport/generichttp/generic_http.go diff --git a/go.mod b/go.mod index f4246bf0..66bd99f4 100644 --- a/go.mod +++ b/go.mod @@ -205,6 +205,7 @@ require ( go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect + go.withmatt.com/httpheaders v1.0.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.39.0 // indirect diff --git a/go.sum b/go.sum index ba8356b2..4214fec9 100644 --- a/go.sum +++ b/go.sum @@ -560,6 +560,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.withmatt.com/httpheaders v1.0.0 
h1:xZhtLWyIWCd8FT3CvUBRQLhQpgZaMmHNfIIT0wwNc1A= +go.withmatt.com/httpheaders v1.0.0/go.mod h1:bKAYNgm9s2ViHIoGOnMKo4F2zJXBdvpfGuSEJQYF8pQ= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= diff --git a/imagedata/download.go b/imagedata/download.go index d955d76e..27a233da 100644 --- a/imagedata/download.go +++ b/imagedata/download.go @@ -1,51 +1,36 @@ package imagedata import ( - "compress/gzip" "context" - "io" "net/http" - "net/http/cookiejar" - "regexp" - "strconv" - "strings" - "time" + "slices" "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/imagefetcher" "github.com/imgproxy/imgproxy/v3/security" - - defaultTransport "github.com/imgproxy/imgproxy/v3/transport" - azureTransport "github.com/imgproxy/imgproxy/v3/transport/azure" - transportCommon "github.com/imgproxy/imgproxy/v3/transport/common" - fsTransport "github.com/imgproxy/imgproxy/v3/transport/fs" - gcsTransport "github.com/imgproxy/imgproxy/v3/transport/gcs" - s3Transport "github.com/imgproxy/imgproxy/v3/transport/s3" - swiftTransport "github.com/imgproxy/imgproxy/v3/transport/swift" + "github.com/imgproxy/imgproxy/v3/transport" + "go.withmatt.com/httpheaders" ) var ( - downloadClient *http.Client - - enabledSchemes = map[string]struct{}{ - "http": {}, - "https": {}, - } - - imageHeadersToStore = []string{ - "Cache-Control", - "Expires", - "ETag", - "Last-Modified", - } - - contentRangeRe = regexp.MustCompile(`^bytes ((\d+)-(\d+)|\*)/(\d+|\*)$`) + Fetcher *imagefetcher.Fetcher // For tests redirectAllRequestsTo string -) -const msgSourceImageIsUnreachable = "Source image is unreachable" + // keepResponseHeaders is a list of HTTP headers that should be preserved in the response + keepResponseHeaders = []string{ + httpheaders.CacheControl, + 
httpheaders.Expires, + httpheaders.LastModified, + // NOTE: + // httpheaders.Etag == "Etag". + // Http header names are case-insensitive, but we rely on the case in most cases. + // We must migrate to http.Headers and the subsequent methods everywhere. + httpheaders.Etag, + } +) type DownloadOptions struct { Header http.Header @@ -53,224 +38,40 @@ type DownloadOptions struct { } func initDownloading() error { - transport, err := defaultTransport.New(true) + ts, err := transport.NewTransport() if err != nil { return err } - registerProtocol := func(scheme string, rt http.RoundTripper) { - transport.RegisterProtocol(scheme, rt) - enabledSchemes[scheme] = struct{}{} - } - - if config.LocalFileSystemRoot != "" { - registerProtocol("local", fsTransport.New()) - } - - if config.S3Enabled { - if t, err := s3Transport.New(); err != nil { - return err - } else { - registerProtocol("s3", t) - } - } - - if config.GCSEnabled { - if t, err := gcsTransport.New(); err != nil { - return err - } else { - registerProtocol("gs", t) - } - } - - if config.ABSEnabled { - if t, err := azureTransport.New(); err != nil { - return err - } else { - registerProtocol("abs", t) - } - } - - if config.SwiftEnabled { - if t, err := swiftTransport.New(); err != nil { - return err - } else { - registerProtocol("swift", t) - } - } - - downloadClient = &http.Client{ - Transport: transport, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - redirects := len(via) - if redirects >= config.MaxRedirects { - return newImageTooManyRedirectsError(redirects) - } - return nil - }, + Fetcher, err = imagefetcher.NewFetcher(ts, config.MaxRedirects) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create image fetcher")) } return nil } -func headersToStore(res *http.Response) map[string]string { - m := make(map[string]string) - - for _, h := range imageHeadersToStore { - if val := res.Header.Get(h); len(val) != 0 { - m[h] = val - } - } - - return m -} - -func 
BuildImageRequest(ctx context.Context, imageURL string, header http.Header, jar http.CookieJar) (*http.Request, context.CancelFunc, error) { - reqCtx, reqCancel := context.WithTimeout(ctx, time.Duration(config.DownloadTimeout)*time.Second) - - imageURL = transportCommon.EscapeURL(imageURL) - - req, err := http.NewRequestWithContext(reqCtx, "GET", imageURL, nil) - if err != nil { - reqCancel() - return nil, func() {}, newImageRequestError(err) - } - - if _, ok := enabledSchemes[req.URL.Scheme]; !ok { - reqCancel() - return nil, func() {}, newImageRequstSchemeError(req.URL.Scheme) - } - - if jar != nil { - for _, cookie := range jar.Cookies(req.URL) { - req.AddCookie(cookie) - } - } - - req.Header.Set("User-Agent", config.UserAgent) - - for k, v := range header { - if len(v) > 0 { - req.Header.Set(k, v[0]) - } - } - - return req, reqCancel, nil -} - -func SendRequest(req *http.Request) (*http.Response, error) { - var client *http.Client - if req.URL.Scheme == "http" || req.URL.Scheme == "https" { - clientCopy := *downloadClient - - jar, err := cookiejar.New(nil) - if err != nil { - return nil, err - } - clientCopy.Jar = jar - client = &clientCopy - } else { - client = downloadClient - } - - for { - res, err := client.Do(req) - if err == nil { - return res, nil - } - - if res != nil && res.Body != nil { - res.Body.Close() - } - - if strings.Contains(err.Error(), "client connection lost") { - select { - case <-req.Context().Done(): - return nil, err - case <-time.After(100 * time.Microsecond): - continue - } - } - - return nil, wrapError(err) - } -} - -func requestImage(ctx context.Context, imageURL string, opts DownloadOptions) (*http.Response, context.CancelFunc, error) { - req, reqCancel, err := BuildImageRequest(ctx, imageURL, opts.Header, opts.CookieJar) - if err != nil { - reqCancel() - return nil, func() {}, err - } - - res, err := SendRequest(req) - if err != nil { - reqCancel() - return nil, func() {}, err - } - - if res.StatusCode == http.StatusNotModified { 
- res.Body.Close() - reqCancel() - return nil, func() {}, newNotModifiedError(headersToStore(res)) - } - - // If the source responds with 206, check if the response contains entire image. - // If not, return an error. - if res.StatusCode == http.StatusPartialContent { - contentRange := res.Header.Get("Content-Range") - rangeParts := contentRangeRe.FindStringSubmatch(contentRange) - if len(rangeParts) == 0 { - res.Body.Close() - reqCancel() - return nil, func() {}, newImagePartialResponseError("Partial response with invalid Content-Range header") - } - - if rangeParts[1] == "*" || rangeParts[2] != "0" { - res.Body.Close() - reqCancel() - return nil, func() {}, newImagePartialResponseError("Partial response with incomplete content") - } - - contentLengthStr := rangeParts[4] - if contentLengthStr == "*" { - contentLengthStr = res.Header.Get("Content-Length") - } - - contentLength, _ := strconv.Atoi(contentLengthStr) - rangeEnd, _ := strconv.Atoi(rangeParts[3]) - - if contentLength <= 0 || rangeEnd != contentLength-1 { - res.Body.Close() - reqCancel() - return nil, func() {}, newImagePartialResponseError("Partial response with incomplete content") - } - } else if res.StatusCode != http.StatusOK { - var body string - - if strings.HasPrefix(res.Header.Get("Content-Type"), "text/") { - bbody, _ := io.ReadAll(io.LimitReader(res.Body, 1024)) - body = string(bbody) - } - - res.Body.Close() - reqCancel() - - return nil, func() {}, newImageResponseStatusError(res.StatusCode, body) - } - - return res, reqCancel, nil -} - func download(ctx context.Context, imageURL string, opts DownloadOptions, secopts security.Options) (*ImageData, error) { // We use this for testing if len(redirectAllRequestsTo) > 0 { imageURL = redirectAllRequestsTo } - res, reqCancel, err := requestImage(ctx, imageURL, opts) - defer reqCancel() + req, err := Fetcher.BuildRequest(ctx, imageURL, opts.Header, opts.CookieJar) + if err != nil { + return nil, err + } + defer req.Cancel() + res, err := 
req.FetchImage() + if err != nil { + if res != nil { + res.Body.Close() + } + return nil, err + } + + res, err = security.LimitResponseSize(res, secopts) if res != nil { defer res.Body.Close() } @@ -278,27 +79,26 @@ func download(ctx context.Context, imageURL string, opts DownloadOptions, secopt return nil, err } - body := res.Body - contentLength := int(res.ContentLength) - - if res.Header.Get("Content-Encoding") == "gzip" { - gzipBody, errGzip := gzip.NewReader(res.Body) - if gzipBody != nil { - defer gzipBody.Close() - } - if errGzip != nil { - return nil, err - } - body = gzipBody - contentLength = 0 - } - - imgdata, err := readAndCheckImage(body, contentLength, secopts) + imgdata, err := readAndCheckImage(res.Body, int(res.ContentLength), secopts) if err != nil { return nil, ierrors.Wrap(err, 0) } - imgdata.Headers = headersToStore(res) + h := make(map[string]string) + for k := range res.Header { + if !slices.Contains(keepResponseHeaders, k) { + continue + } + + // TODO: Fix Etag/ETag inconsistency + if k == "Etag" { + h["ETag"] = res.Header.Get(k) + } else { + h[k] = res.Header.Get(k) + } + } + + imgdata.Headers = h return imgdata, nil } diff --git a/imagedata/read.go b/imagedata/read.go index c509827f..988df8dd 100644 --- a/imagedata/read.go +++ b/imagedata/read.go @@ -8,6 +8,7 @@ import ( "github.com/imgproxy/imgproxy/v3/bufpool" "github.com/imgproxy/imgproxy/v3/bufreader" "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/imagefetcher" "github.com/imgproxy/imgproxy/v3/imagemeta" "github.com/imgproxy/imgproxy/v3/security" ) @@ -19,15 +20,9 @@ func initRead() { } func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options) (*ImageData, error) { - if err := security.CheckFileSize(contentLength, secopts); err != nil { - return nil, err - } - buf := downloadBufPool.Get(contentLength, false) cancel := func() { downloadBufPool.Put(buf) } - r = security.LimitFileSize(r, secopts) - br := bufreader.New(r, buf) meta, err 
:= imagemeta.DecodeMeta(br) @@ -35,14 +30,14 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options) buf.Reset() cancel() - return nil, wrapError(err) + return nil, imagefetcher.WrapError(err) } if err = security.CheckDimensions(meta.Width(), meta.Height(), 1, secopts); err != nil { buf.Reset() cancel() - return nil, wrapError(err) + return nil, imagefetcher.WrapError(err) } downloadBufPool.GrowBuffer(buf, contentLength) @@ -51,7 +46,7 @@ func readAndCheckImage(r io.Reader, contentLength int, secopts security.Options) buf.Reset() cancel() - return nil, wrapError(err) + return nil, imagefetcher.WrapError(err) } return &ImageData{ diff --git a/imagedata/errors.go b/imagefetcher/errors.go similarity index 92% rename from imagedata/errors.go rename to imagefetcher/errors.go index c5e569b0..1d8eac9d 100644 --- a/imagedata/errors.go +++ b/imagefetcher/errors.go @@ -1,4 +1,4 @@ -package imagedata +package imagefetcher import ( "context" @@ -10,6 +10,8 @@ import ( "github.com/imgproxy/imgproxy/v3/security" ) +const msgSourceImageIsUnreachable = "Source image is unreachable" + type ( ImageRequestError struct{ error } ImageRequstSchemeError string @@ -20,7 +22,7 @@ type ( ImageRequestTimeoutError struct{ error } NotModifiedError struct { - headers map[string]string + headers http.Header } httpError interface { @@ -135,7 +137,7 @@ func (e ImageRequestTimeoutError) Error() string { func (e ImageRequestTimeoutError) Unwrap() error { return e.error } -func newNotModifiedError(headers map[string]string) error { +func newNotModifiedError(headers http.Header) error { return ierrors.Wrap( NotModifiedError{headers}, 1, @@ -147,11 +149,12 @@ func newNotModifiedError(headers map[string]string) error { func (e NotModifiedError) Error() string { return "Not modified" } -func (e NotModifiedError) Headers() map[string]string { +func (e NotModifiedError) Headers() http.Header { return e.headers } -func wrapError(err error) error { +// NOTE: make private when we 
remove download functions from imagedata package +func WrapError(err error) error { isTimeout := false var secArrdErr security.SourceAddressError diff --git a/imagefetcher/fetcher.go b/imagefetcher/fetcher.go new file mode 100644 index 00000000..1f46bd08 --- /dev/null +++ b/imagefetcher/fetcher.go @@ -0,0 +1,86 @@ +// Package imagefetcher is responsible for downloading images using HTTP requests through various protocols +// defined in the transport package +package imagefetcher + +import ( + "context" + "net/http" + "time" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/transport" + "github.com/imgproxy/imgproxy/v3/transport/common" + "go.withmatt.com/httpheaders" ) + +const ( + connectionLostError = "client connection lost" // Error message indicating a lost connection + bounceDelay = 100 * time.Microsecond // Delay before retrying a request +) + +// Fetcher is a struct that holds the HTTP client and transport for fetching images +type Fetcher struct { + transport *transport.Transport // Transport used for making HTTP requests + maxRedirects int // Maximum number of redirects allowed +} + +// NewFetcher creates a new Fetcher with the provided transport and redirect limit +func NewFetcher(transport *transport.Transport, maxRedirects int) (*Fetcher, error) { + return &Fetcher{transport, maxRedirects}, nil +} + +// checkRedirect is a method that checks if the number of redirects exceeds the maximum allowed +func (f *Fetcher) checkRedirect(req *http.Request, via []*http.Request) error { + redirects := len(via) + if redirects >= f.maxRedirects { + return newImageTooManyRedirectsError(redirects) + } + return nil +} + +// newHttpClient returns new HTTP client +func (f *Fetcher) newHttpClient() *http.Client { + return &http.Client{ + Transport: f.transport.Transport(), // Connection pool is there + CheckRedirect: f.checkRedirect, + } +} + +// BuildRequest creates a new Request with the provided context, URL, headers, and cookie jar +func (f 
*Fetcher) BuildRequest(ctx context.Context, url string, header http.Header, jar http.CookieJar) (*Request, error) { + url = common.EscapeURL(url) + + // Set request timeout and get cancel function + ctx, cancel := context.WithTimeout(ctx, time.Duration(config.DownloadTimeout)*time.Second) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + cancel() + return nil, newImageRequestError(err) + } + + // Check if the URL scheme is supported + if !f.transport.IsProtocolRegistered(req.URL.Scheme) { + cancel() + return nil, newImageRequstSchemeError(req.URL.Scheme) + } + + // Add cookies from the jar to the request (if any) + if jar != nil { + for _, cookie := range jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + + // Set user agent header + req.Header.Set(httpheaders.UserAgent, config.UserAgent) + + // Set headers + for k, v := range header { + if len(v) > 0 { + req.Header.Set(k, v[0]) + } + } + + return &Request{f, req, cancel}, nil +} diff --git a/imagefetcher/request.go b/imagefetcher/request.go new file mode 100644 index 00000000..db9b7cbe --- /dev/null +++ b/imagefetcher/request.go @@ -0,0 +1,204 @@ +package imagefetcher + +import ( + "compress/gzip" + "context" + "io" + "net/http" + "net/http/cookiejar" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "go.withmatt.com/httpheaders" +) + +var ( + // contentRangeRe Content-Range header regex to check if the response is a partial content response + contentRangeRe = regexp.MustCompile(`^bytes ((\d+)-(\d+)|\*)/(\d+|\*)$`) +) + +// Request is a struct that holds the request and cancel function for an image fetcher request +type Request struct { + fetcher *Fetcher // Parent ImageFetcher instance + request *http.Request // HTTP request to fetch the image + cancel context.CancelFunc // Request context cancel function +} + +// Send sends the generic request and returns the http.Response or an error +func (r *Request) Send() (*http.Response, error) { + client := 
r.fetcher.newHttpClient() + + // Let's add a cookie jar to the client if the request URL is HTTP or HTTPS + // This is necessary to pass cookie challenge for some servers. + if r.request.URL.Scheme == "http" || r.request.URL.Scheme == "https" { + jar, err := cookiejar.New(nil) + if err != nil { + return nil, err + } + client.Jar = jar + } + + for { + // Try request + res, err := client.Do(r.request) + if err == nil { + return res, nil // Return successful response + } + + // Close the response body if request was unsuccessful + if res != nil && res.Body != nil { + res.Body.Close() + } + + // Retry if the error is due to a lost connection + if strings.Contains(err.Error(), connectionLostError) { + select { + case <-r.request.Context().Done(): + return nil, err + case <-time.After(bounceDelay): + continue + } + } + + return nil, WrapError(err) + } +} + +// FetchImage fetches the image using the request and returns the response or an error. +// It checks for the NotModified status and handles partial content responses. +func (r *Request) FetchImage() (*http.Response, error) { + res, err := r.Send() + if err != nil { + r.cancel() + return nil, err + } + + // Closes the response body and cancels request context + cancel := func() { + res.Body.Close() + r.cancel() + } + + // If the source image was not modified, close the body and NotModifiedError + if res.StatusCode == http.StatusNotModified { + cancel() + return nil, newNotModifiedError(res.Header) + } + + // If the source responds with 206, check if the response contains an entire image. + // If not, return an error. 
+ if res.StatusCode == http.StatusPartialContent { + err = checkPartialContentResponse(res) + if err != nil { + cancel() + return nil, err + } + } else if res.StatusCode != http.StatusOK { + body := extractErraticBody(res) + cancel() + return nil, newImageResponseStatusError(res.StatusCode, body) + } + + // If the response is gzip encoded, wrap it in a gzip reader + err = wrapGzipBody(res) + if err != nil { + cancel() + return nil, err + } + + // Wrap the response body in a bodyReader to ensure the request context + // is cancelled when the body is closed + res.Body = &bodyReader{ + body: res.Body, + request: r, + } + + return res, nil +} + +// Cancel cancels the request context +func (r *Request) Cancel() { + r.cancel() +} + +// URL returns the actual URL of the request +func (r *Request) URL() *url.URL { + return r.request.URL +} + +// checkPartialContentResponse if the response is a partial content response, +// we check if it contains the entire image. +func checkPartialContentResponse(res *http.Response) error { + contentRange := res.Header.Get(httpheaders.ContentRange) + rangeParts := contentRangeRe.FindStringSubmatch(contentRange) + + if len(rangeParts) == 0 { + return newImagePartialResponseError("Partial response with invalid Content-Range header") + } + + if rangeParts[1] == "*" || rangeParts[2] != "0" { + return newImagePartialResponseError("Partial response with incomplete content") + } + + contentLengthStr := rangeParts[4] + if contentLengthStr == "*" { + contentLengthStr = res.Header.Get(httpheaders.ContentLength) + } + + contentLength, _ := strconv.Atoi(contentLengthStr) + rangeEnd, _ := strconv.Atoi(rangeParts[3]) + + if contentLength <= 0 || rangeEnd != contentLength-1 { + return newImagePartialResponseError("Partial response with incomplete content") + } + + return nil +} + +// extractErraticBody extracts the error body from the response if it is a text-based content type +func extractErraticBody(res *http.Response) string { + if 
strings.HasPrefix(res.Header.Get(httpheaders.ContentType), "text/") { + bbody, _ := io.ReadAll(io.LimitReader(res.Body, 1024)) + return string(bbody) + } + + return "" +} + +// wrapGzipBody wraps the response body in a gzip reader if the Content-Encoding is gzip. +// We set DisableCompression: true to avoid sending the Accept-Encoding: gzip header, +// since we do not want to compress image data (which is usually already compressed). +// However, some servers still send gzip-encoded responses regardless. +func wrapGzipBody(res *http.Response) error { + if res.Header.Get(httpheaders.ContentEncoding) == "gzip" { + gzipBody, err := gzip.NewReader(res.Body) + if err != nil { + return err + } + res.Body = gzipBody + res.Header.Del(httpheaders.ContentEncoding) + } + + return nil +} + +// bodyReader is a wrapper around io.ReadCloser which closes original request context +// when the body is closed. +type bodyReader struct { + body io.ReadCloser // The body to read from + request *Request +} + +// Read reads data from the response body into the provided byte slice +func (r *bodyReader) Read(p []byte) (int, error) { + return r.body.Read(p) +} + +// Close closes the response body and cancels the request context +func (r *bodyReader) Close() error { + defer r.request.cancel() + return r.body.Close() +} diff --git a/processing_handler.go b/processing_handler.go index 94fc805d..32ab014b 100644 --- a/processing_handler.go +++ b/processing_handler.go @@ -20,6 +20,7 @@ import ( "github.com/imgproxy/imgproxy/v3/etag" "github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/imagedata" + "github.com/imgproxy/imgproxy/v3/imagefetcher" "github.com/imgproxy/imgproxy/v3/imagetype" "github.com/imgproxy/imgproxy/v3/imath" "github.com/imgproxy/imgproxy/v3/metrics" @@ -348,7 +349,7 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { return imagedata.Download(ctx, imageURL, "source image", downloadOpts, po.SecurityOptions) }() - var nmErr 
imagedata.NotModifiedError + var nmErr imagefetcher.NotModifiedError switch { case err == nil: @@ -358,7 +359,13 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) { if config.ETagEnabled && len(etagHandler.ImageEtagExpected()) != 0 { rw.Header().Set("ETag", etagHandler.GenerateExpectedETag()) } - respondWithNotModified(reqID, r, rw, po, imageURL, nmErr.Headers()) + + h := make(map[string]string) + for k := range nmErr.Headers() { + h[k] = nmErr.Headers().Get(k) + } + + respondWithNotModified(reqID, r, rw, po, imageURL, h) return default: diff --git a/security/file_size.go b/security/file_size.go deleted file mode 100644 index 44dc5f9a..00000000 --- a/security/file_size.go +++ /dev/null @@ -1,38 +0,0 @@ -package security - -import ( - "io" -) - -type hardLimitReader struct { - r io.Reader - left int -} - -func (lr *hardLimitReader) Read(p []byte) (n int, err error) { - if lr.left <= 0 { - return 0, newFileSizeError() - } - if len(p) > lr.left { - p = p[0:lr.left] - } - n, err = lr.r.Read(p) - lr.left -= n - return -} - -func CheckFileSize(size int, opts Options) error { - if opts.MaxSrcFileSize > 0 && size > opts.MaxSrcFileSize { - return newFileSizeError() - } - - return nil -} - -func LimitFileSize(r io.Reader, opts Options) io.Reader { - if opts.MaxSrcFileSize > 0 { - return &hardLimitReader{r: r, left: opts.MaxSrcFileSize} - } - - return r -} diff --git a/security/response_limit.go b/security/response_limit.go new file mode 100644 index 00000000..ffcc0675 --- /dev/null +++ b/security/response_limit.go @@ -0,0 +1,51 @@ +package security + +import ( + "io" + "net/http" +) + +// hardLimitReadCloser is a wrapper around io.ReadCloser +// that limits the number of bytes it can read from the upstream reader. 
+type hardLimitReadCloser struct { + r io.ReadCloser + left int +} + +func (lr *hardLimitReadCloser) Read(p []byte) (n int, err error) { + if lr.left <= 0 { + return 0, newFileSizeError() + } + if len(p) > lr.left { + p = p[0:lr.left] + } + n, err = lr.r.Read(p) + lr.left -= n + return +} + +func (lr *hardLimitReadCloser) Close() error { + return lr.r.Close() +} + +// LimitResponseSize limits the size of the response body to MaxSrcFileSize (if set). +// First, it tries to use Content-Length header to check the limit. +// If Content-Length is not set, it limits the size of the response body by wrapping +// body reader with hard limit reader. +func LimitResponseSize(r *http.Response, opts Options) (*http.Response, error) { + if opts.MaxSrcFileSize == 0 { + return r, nil + } + + // If Content-Length was set, limit the size of the response body before reading it + size := int(r.ContentLength) + + if size > opts.MaxSrcFileSize { + return r, newFileSizeError() + } + + // hard-limit the response body reader + r.Body = &hardLimitReadCloser{r: r.Body, left: opts.MaxSrcFileSize} + + return r, nil +} diff --git a/stream.go b/stream.go index 29c59c35..2b280f35 100644 --- a/stream.go +++ b/stream.go @@ -69,11 +69,11 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht checkErr(ctx, "streaming", err) } - req, reqCancel, err := imagedata.BuildImageRequest(r.Context(), imageURL, imgRequestHeader, cookieJar) - defer reqCancel() + req, err := imagedata.Fetcher.BuildRequest(r.Context(), imageURL, imgRequestHeader, cookieJar) checkErr(ctx, "streaming", err) + defer req.Cancel() - res, err := imagedata.SendRequest(req) + res, err := req.Send() if res != nil { defer res.Body.Close() } @@ -93,7 +93,7 @@ func streamOriginImage(ctx context.Context, reqID string, r *http.Request, rw ht if res.StatusCode < 300 { var filename, ext, mimetype string - _, filename = filepath.Split(req.URL.Path) + _, filename = filepath.Split(req.URL().Path) ext = 
filepath.Ext(filename) if len(po.Filename) > 0 { diff --git a/transport/azure/azure.go b/transport/azure/azure.go index 92acfb6d..7c308317 100644 --- a/transport/azure/azure.go +++ b/transport/azure/azure.go @@ -18,8 +18,8 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/httprange" - defaultTransport "github.com/imgproxy/imgproxy/v3/transport" "github.com/imgproxy/imgproxy/v3/transport/common" + "github.com/imgproxy/imgproxy/v3/transport/generichttp" "github.com/imgproxy/imgproxy/v3/transport/notmodified" ) @@ -49,7 +49,7 @@ func New() (http.RoundTripper, error) { return nil, err } - trans, err := defaultTransport.New(false) + trans, err := generichttp.New(false) if err != nil { return nil, err } diff --git a/transport/gcs/gcs.go b/transport/gcs/gcs.go index 2f3bdc4e..88afe607 100644 --- a/transport/gcs/gcs.go +++ b/transport/gcs/gcs.go @@ -17,8 +17,8 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/httprange" "github.com/imgproxy/imgproxy/v3/ierrors" - defaultTransport "github.com/imgproxy/imgproxy/v3/transport" "github.com/imgproxy/imgproxy/v3/transport/common" + "github.com/imgproxy/imgproxy/v3/transport/generichttp" "github.com/imgproxy/imgproxy/v3/transport/notmodified" ) @@ -30,7 +30,7 @@ type transport struct { } func buildHTTPClient(opts ...option.ClientOption) (*http.Client, error) { - trans, err := defaultTransport.New(false) + trans, err := generichttp.New(false) if err != nil { return nil, err } diff --git a/transport/generichttp/generic_http.go b/transport/generichttp/generic_http.go new file mode 100644 index 00000000..ce960692 --- /dev/null +++ b/transport/generichttp/generic_http.go @@ -0,0 +1,59 @@ +// Generic HTTP transport for imgproxy +package generichttp + +import ( + "crypto/tls" + "net" + "net/http" + "syscall" + "time" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/security" + "golang.org/x/net/http2" +) + +func New(verifyNetworks bool) 
(*http.Transport, error) { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + } + + if verifyNetworks { + dialer.Control = func(network, address string, c syscall.RawConn) error { + return security.VerifySourceNetwork(address) + } + } + + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: config.Workers + 1, + IdleConnTimeout: time.Duration(config.ClientKeepAliveTimeout) * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: false, + DisableCompression: true, + } + + if config.ClientKeepAliveTimeout <= 0 { + transport.MaxIdleConnsPerHost = -1 + transport.DisableKeepAlives = true + } + + if config.IgnoreSslVerification { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + transport2, err := http2.ConfigureTransports(transport) + if err != nil { + return nil, err + } + + transport2.PingTimeout = 5 * time.Second + transport2.ReadIdleTimeout = time.Second + + return transport, nil +} diff --git a/transport/s3/s3.go b/transport/s3/s3.go index 3aa6d8dd..0eee8a5b 100644 --- a/transport/s3/s3.go +++ b/transport/s3/s3.go @@ -22,8 +22,8 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/ierrors" - defaultTransport "github.com/imgproxy/imgproxy/v3/transport" "github.com/imgproxy/imgproxy/v3/transport/common" + "github.com/imgproxy/imgproxy/v3/transport/generichttp" ) type s3Client interface { @@ -49,7 +49,7 @@ func New() (http.RoundTripper, error) { return nil, ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load AWS S3 config")) } - trans, err := defaultTransport.New(false) + trans, err := generichttp.New(false) if err != nil { return nil, err } diff --git a/transport/swift/swift.go b/transport/swift/swift.go index 8657753c..8fd4c8fb 100644 --- a/transport/swift/swift.go +++ b/transport/swift/swift.go @@ 
-12,8 +12,8 @@ import ( "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/ierrors" - defaultTransport "github.com/imgproxy/imgproxy/v3/transport" "github.com/imgproxy/imgproxy/v3/transport/common" + "github.com/imgproxy/imgproxy/v3/transport/generichttp" "github.com/imgproxy/imgproxy/v3/transport/notmodified" ) @@ -22,7 +22,7 @@ type transport struct { } func New() (http.RoundTripper, error) { - trans, err := defaultTransport.New(false) + trans, err := generichttp.New(false) if err != nil { return nil, err } diff --git a/transport/transport.go b/transport/transport.go index cd60f749..64a33035 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -1,59 +1,106 @@ +// Package transport provides a custom HTTP transport that supports multiple protocols +// such as S3, GCS, ABS, Swift, and local file system. package transport import ( - "crypto/tls" - "net" "net/http" - "syscall" - "time" - - "golang.org/x/net/http2" "github.com/imgproxy/imgproxy/v3/config" - "github.com/imgproxy/imgproxy/v3/security" + "github.com/imgproxy/imgproxy/v3/transport/generichttp" + + azureTransport "github.com/imgproxy/imgproxy/v3/transport/azure" + fsTransport "github.com/imgproxy/imgproxy/v3/transport/fs" + gcsTransport "github.com/imgproxy/imgproxy/v3/transport/gcs" + s3Transport "github.com/imgproxy/imgproxy/v3/transport/s3" + swiftTransport "github.com/imgproxy/imgproxy/v3/transport/swift" ) -func New(verifyNetworks bool) (*http.Transport, error) { - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - } +// Transport is a wrapper around http.Transport which allows to track registered protocols +type Transport struct { + transport *http.Transport + schemes map[string]struct{} +} - if verifyNetworks { - dialer.Control = func(network, address string, c syscall.RawConn) error { - return security.VerifySourceNetwork(address) - } - } - - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - 
DialContext: dialer.DialContext, - MaxIdleConns: 100, - MaxIdleConnsPerHost: config.Workers + 1, - IdleConnTimeout: time.Duration(config.ClientKeepAliveTimeout) * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ForceAttemptHTTP2: false, - DisableCompression: true, - } - - if config.ClientKeepAliveTimeout <= 0 { - transport.MaxIdleConnsPerHost = -1 - transport.DisableKeepAlives = true - } - - if config.IgnoreSslVerification { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - transport2, err := http2.ConfigureTransports(transport) +// NewTransport creates a new HTTP transport with no protocols registered +func NewTransport() (*Transport, error) { + transport, err := generichttp.New(true) if err != nil { return nil, err } - transport2.PingTimeout = 5 * time.Second - transport2.ReadIdleTimeout = time.Second + // http and https are always registered + schemes := map[string]struct{}{ + "http": {}, + "https": {}, + } - return transport, nil + t := &Transport{ + transport, + schemes, + } + + err = t.registerAllProtocols() + if err != nil { + return nil, err + } + + return t, nil +} + +// Transport returns the underlying http.Transport +func (t *Transport) Transport() *http.Transport { + return t.transport +} + +// RegisterProtocol registers a new transport protocol with the transport +func (t *Transport) RegisterProtocol(scheme string, rt http.RoundTripper) { + t.transport.RegisterProtocol(scheme, rt) + t.schemes[scheme] = struct{}{} +} + +// IsProtocolRegistered checks if a protocol is registered in the transport +func (t *Transport) IsProtocolRegistered(scheme string) bool { + _, ok := t.schemes[scheme] + return ok +} + +// RegisterAllProtocols registers all enabled protocols in the given transport +func (t *Transport) registerAllProtocols() error { + if config.LocalFileSystemRoot != "" { + t.RegisterProtocol("local", fsTransport.New()) + } + + if config.S3Enabled { + if tr, err := 
s3Transport.New(); err != nil { + return err + } else { + t.RegisterProtocol("s3", tr) + } + } + + if config.GCSEnabled { + if tr, err := gcsTransport.New(); err != nil { + return err + } else { + t.RegisterProtocol("gs", tr) + } + } + + if config.ABSEnabled { + if tr, err := azureTransport.New(); err != nil { + return err + } else { + t.RegisterProtocol("abs", tr) + } + } + + if config.SwiftEnabled { + if tr, err := swiftTransport.New(); err != nil { + return err + } else { + t.RegisterProtocol("swift", tr) + } + } + + return nil }