From 8bc70491fb210eeaa0cbb7e7166b2f15bfcf13e6 Mon Sep 17 00:00:00 2001 From: Viktor Sokolov Date: Tue, 26 Aug 2025 16:19:41 +0200 Subject: [PATCH] processing_handler.go -> handlers/processing --- auximageprovider/provider.go | 18 + auximageprovider/static_config.go | 37 ++ auximageprovider/static_provider.go | 52 +++ auximageprovider/static_provider_test.go | 200 ++++++++++ fix_path.go | 22 -- handlers/processing/config.go | 59 +++ errors.go => handlers/processing/errors.go | 30 +- handlers/processing/handler.go | 143 +++++++ handlers/processing/path.go | 56 +++ handlers/processing/path_test.go | 180 +++++++++ handlers/processing/request.go | 140 +++++++ handlers/processing/request_methods.go | 278 ++++++++++++++ handlers/stream/config.go | 2 +- handlers/stream/handler.go | 17 +- headerwriter/config.go | 2 +- headerwriter/writer.go | 37 +- headerwriter/writer_test.go | 11 +- imagedata/download.go | 2 +- imagedata/image_data.go | 36 +- imagefetcher/config.go | 2 +- main.go | 76 +++- processing_handler.go | 425 --------------------- processing_handler_test.go | 5 + security/signature.go | 14 +- security/signature_test.go | 10 +- semaphores/config.go | 43 +++ semaphores/errors.go | 21 + semaphores/semaphores.go | 65 ++++ semaphores/semaphores_test.go | 50 +++ server/config.go | 2 +- 30 files changed, 1489 insertions(+), 546 deletions(-) create mode 100644 auximageprovider/provider.go create mode 100644 auximageprovider/static_config.go create mode 100644 auximageprovider/static_provider.go create mode 100644 auximageprovider/static_provider_test.go delete mode 100644 fix_path.go create mode 100644 handlers/processing/config.go rename errors.go => handlers/processing/errors.go (61%) create mode 100644 handlers/processing/handler.go create mode 100644 handlers/processing/path.go create mode 100644 handlers/processing/path_test.go create mode 100644 handlers/processing/request.go create mode 100644 handlers/processing/request_methods.go delete mode 100644 processing_handler.go create mode 100644 semaphores/config.go create mode 100644 semaphores/errors.go create mode 100644 semaphores/semaphores.go create mode 100644 semaphores/semaphores_test.go diff --git a/auximageprovider/provider.go b/auximageprovider/provider.go new file mode 100644 index 00000000..87de72f1 --- /dev/null +++ b/auximageprovider/provider.go @@ -0,0 +1,18 @@ +// auximagedata exposes an interface for retreiving auxiliary images +// such as watermarks and fallbacks. Default implementation stores those in memory. + +package auximageprovider + +import ( + "context" + "net/http" + + "github.com/imgproxy/imgproxy/v3/imagedata" + "github.com/imgproxy/imgproxy/v3/options" +) + +// Provider is an interface that provides image data and headers based +// on processing options. It is used to retrieve WatermarkImage and FallbackImage. +type Provider interface { + Get(context.Context, *options.ProcessingOptions) (imagedata.ImageData, http.Header, error) +} diff --git a/auximageprovider/static_config.go b/auximageprovider/static_config.go new file mode 100644 index 00000000..1bc1ba48 --- /dev/null +++ b/auximageprovider/static_config.go @@ -0,0 +1,37 @@ +package auximageprovider + +import "github.com/imgproxy/imgproxy/v3/config" + +// StaticConfig holds the configuration for the auxiliary image provider +type StaticConfig struct { + Base64Data string + Path string + URL string +} + +// NewDefaultStaticConfig creates a new default configuration for the auxiliary image provider +func NewDefaultStaticConfig() *StaticConfig { + return &StaticConfig{ + Base64Data: "", + Path: "", + URL: "", + } +} + +// LoadWatermarkStaticConfigFromEnv loads the watermark configuration from the environment +func LoadWatermarkStaticConfigFromEnv(c *StaticConfig) (*StaticConfig, error) { + c.Base64Data = config.WatermarkData + c.Path = config.WatermarkPath + c.URL = config.WatermarkURL + + return c, nil +} + +// LoadFallbackStaticConfigFromEnv loads the fallback configuration from the environment +func LoadFallbackStaticConfigFromEnv(c *StaticConfig) (*StaticConfig, error) { + c.Base64Data = config.FallbackImageData + c.Path = config.FallbackImagePath + c.URL = config.FallbackImageURL + + return c, nil +} diff --git a/auximageprovider/static_provider.go b/auximageprovider/static_provider.go new file mode 100644 index 00000000..353abf18 --- /dev/null +++ b/auximageprovider/static_provider.go @@ -0,0 +1,52 @@ +package auximageprovider + +import ( + "context" + "net/http" + + "github.com/imgproxy/imgproxy/v3/imagedata" + "github.com/imgproxy/imgproxy/v3/options" +) + +// staticProvider is a simple implementation of ImageProvider, which returns +// a static saved image data and headers. +type staticProvider struct { + data imagedata.ImageData + headers http.Header +} + +// Get returns the static image data and headers stored in the provider. +func (s *staticProvider) Get(_ context.Context, po *options.ProcessingOptions) (imagedata.ImageData, http.Header, error) { + return s.data, s.headers.Clone(), nil +} + +// NewStaticFromTriple creates a new ImageProvider from either a base64 string, file path, or URL +func NewStaticProvider(ctx context.Context, c *StaticConfig, desc string) (Provider, error) { + var ( + data imagedata.ImageData + headers = make(http.Header) + err error + ) + + switch { + case len(c.Base64Data) > 0: + data, err = imagedata.NewFromBase64(c.Base64Data) + case len(c.Path) > 0: + data, err = imagedata.NewFromPath(c.Path) + case len(c.URL) > 0: + data, headers, err = imagedata.DownloadSync( + ctx, c.URL, desc, imagedata.DownloadOptions{}, + ) + default: + return nil, nil + } + + if err != nil { + return nil, err + } + + return &staticProvider{ + data: data, + headers: headers, + }, nil +} diff --git a/auximageprovider/static_provider_test.go b/auximageprovider/static_provider_test.go new file mode 100644 index 00000000..fef9d4b1 --- /dev/null +++ b/auximageprovider/static_provider_test.go @@ -0,0 +1,200 @@ +package auximageprovider + +import ( + "encoding/base64" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/imagedata" + "github.com/imgproxy/imgproxy/v3/options" +) + +type ImageProviderTestSuite struct { + suite.Suite + + server *httptest.Server + testData []byte + testDataB64 string + + // Server state + status int + data []byte + header http.Header +} + +func (s *ImageProviderTestSuite) SetupSuite() { + config.Reset() + config.AllowLoopbackSourceAddresses = true + + // Load test image data + f, err := os.Open("../testdata/test1.jpg") + s.Require().NoError(err) + defer f.Close() + + data, err := io.ReadAll(f) + s.Require().NoError(err) + + s.testData = data + s.testDataB64 = base64.StdEncoding.EncodeToString(data) + + // Create test server + s.server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + for k, vv := range s.header { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + + data := s.data + if data == nil { + data = s.testData + } + + rw.Header().Set(httpheaders.ContentLength, strconv.Itoa(len(data))) + rw.WriteHeader(s.status) + rw.Write(data) + })) + + s.Require().NoError(imagedata.Init()) +} + +func (s *ImageProviderTestSuite) TearDownSuite() { + s.server.Close() +} + +func (s *ImageProviderTestSuite) SetupTest() { + s.status = http.StatusOK + s.data = nil + s.header = http.Header{} + s.header.Set(httpheaders.ContentType, "image/jpeg") +} + +// Helper function to read data from ImageData +func (s *ImageProviderTestSuite) readImageData(provider Provider) []byte { + imgData, _, err := provider.Get(s.T().Context(), &options.ProcessingOptions{}) + s.Require().NoError(err) + s.Require().NotNil(imgData) + defer imgData.Close() + + reader := imgData.Reader() + data, err := io.ReadAll(reader) + s.Require().NoError(err) + return data +} + +func (s *ImageProviderTestSuite) TestNewProvider() { + tests := []struct { + name string + config *StaticConfig + setupFunc func() + expectError bool + expectNil bool + validateFunc func(provider Provider) + }{ + { + name: "B64", + config: &StaticConfig{Base64Data: s.testDataB64}, + validateFunc: func(provider Provider) { + s.Equal(s.testData, s.readImageData(provider)) + }, + }, + { + name: "Path", + config: &StaticConfig{Path: "../testdata/test1.jpg"}, + validateFunc: func(provider Provider) { + s.Equal(s.testData, s.readImageData(provider)) + }, + }, + { + name: "URL", + config: &StaticConfig{URL: s.server.URL}, + validateFunc: func(provider Provider) { + s.Equal(s.testData, s.readImageData(provider)) + }, + }, + { + name: "EmptyConfig", + config: &StaticConfig{}, + expectNil: true, + }, + { + name: "InvalidURL", + config: &StaticConfig{URL: "http://invalid-url-that-does-not-exist.invalid"}, + expectError: true, + expectNil: true, + }, + { + name: "InvalidBase64", + config: &StaticConfig{Base64Data: "invalid-base64-data!!!"}, + expectError: true, + expectNil: true, + }, + { + name: "Base64PreferenceOverPath", + config: &StaticConfig{ + Base64Data: base64.StdEncoding.EncodeToString(s.testData), + Path: "../testdata/test2.jpg", // This should be ignored + }, + validateFunc: func(provider Provider) { + actualData := s.readImageData(provider) + s.Equal(s.testData, actualData) + }, + }, + { + name: "HeadersPassedThrough", + config: &StaticConfig{URL: s.server.URL}, + setupFunc: func() { + s.header.Set("X-Custom-Header", "test-value") + s.header.Set(httpheaders.CacheControl, "max-age=3600") + }, + validateFunc: func(provider Provider) { + imgData, headers, err := provider.Get(s.T().Context(), &options.ProcessingOptions{}) + s.Require().NoError(err) + s.Require().NotNil(imgData) + defer imgData.Close() + + s.Equal("test-value", headers.Get("X-Custom-Header")) + s.Equal("max-age=3600", headers.Get(httpheaders.CacheControl)) + s.Equal("image/jpeg", headers.Get(httpheaders.ContentType)) + }, + }, + } + + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + if tt.setupFunc != nil { + tt.setupFunc() + } + + provider, err := NewStaticProvider(s.T().Context(), tt.config, "test image") + + if tt.expectError { + s.Require().Error(err) + } else { + s.Require().NoError(err) + } + + if tt.expectNil { + s.Nil(provider) + } else { + s.Require().NotNil(provider) + } + + if tt.validateFunc != nil { + tt.validateFunc(provider) + } + }) + } +} + +func TestImageProvider(t *testing.T) { + suite.Run(t, new(ImageProviderTestSuite)) +} diff --git a/fix_path.go b/fix_path.go deleted file mode 100644 index ee04183f..00000000 --- a/fix_path.go +++ /dev/null @@ -1,22 +0,0 @@ -package main - -import ( - "fmt" - "regexp" - "strings" -) - -var fixPathRe = regexp.MustCompile(`/plain/(\S+)\:/([^/])`) - -func fixPath(path string) string { - for _, match := range fixPathRe.FindAllStringSubmatch(path, -1) { - repl := fmt.Sprintf("/plain/%s://", match[1]) - if match[1] == "local" { - repl += "/" - } - repl += match[2] - path = strings.Replace(path, match[0], repl, 1) - } - - return path -} diff --git a/handlers/processing/config.go b/handlers/processing/config.go new file mode 100644 index 00000000..493368fe --- /dev/null +++ b/handlers/processing/config.go @@ -0,0 +1,59 @@ +package processing + +import ( + "errors" + "net/http" + + "github.com/imgproxy/imgproxy/v3/config" +) + +// Config represents handler config +type Config struct { + PathPrefix string // Route path prefix + CookiePassthrough bool // Whether to passthrough cookies + ReportDownloadingErrors bool // Whether to report downloading errors + LastModifiedEnabled bool // Whether to enable Last-Modified + ETagEnabled bool // Whether to enable ETag + ReportIOErrors bool // Whether to report IO errors + FallbackImageHTTPCode int // Fallback image HTTP status code + EnableDebugHeaders bool // Whether to enable debug headers + FallbackImageData string // Fallback image data (base64) + FallbackImagePath string // Fallback image path (local file system) + FallbackImageURL string // Fallback image URL (remote) +} + +// NewDefaultConfig creates a new configuration with defaults +func NewDefaultConfig() *Config { + return &Config{ + PathPrefix: "", + CookiePassthrough: false, + ReportDownloadingErrors: true, + LastModifiedEnabled: true, + ETagEnabled: true, + ReportIOErrors: false, + FallbackImageHTTPCode: http.StatusOK, + EnableDebugHeaders: false, + } +} + +// LoadFromEnv loads config from environment variables +func LoadFromEnv(c *Config) (*Config, error) { + c.PathPrefix = config.PathPrefix + c.CookiePassthrough = config.CookiePassthrough + c.ReportDownloadingErrors = config.ReportDownloadingErrors + c.LastModifiedEnabled = config.LastModifiedEnabled + c.ETagEnabled = config.ETagEnabled + c.ReportIOErrors = config.ReportIOErrors + c.FallbackImageHTTPCode = config.FallbackImageHTTPCode + c.EnableDebugHeaders = config.EnableDebugHeaders + + return c, nil +} + +// Validate checks configuration values +func (c *Config) Validate() error { + if c.FallbackImageHTTPCode != 0 && (c.FallbackImageHTTPCode < 100 || c.FallbackImageHTTPCode > 599) { + return errors.New("fallback image HTTP code should be between 100 and 599") + } + return nil +} diff --git a/errors.go b/handlers/processing/errors.go similarity index 61% rename from errors.go rename to handlers/processing/errors.go index 7d660ccf..bf58d0d8 100644 --- a/errors.go +++ b/handlers/processing/errors.go @@ -1,10 +1,11 @@ -package main +package processing import ( "fmt" "net/http" "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/imagetype" ) // Monitoring error categories @@ -21,9 +22,8 @@ const ( ) type ( - ResponseWriteError struct{ error } - InvalidURLError string - TooManyRequestsError struct{} + ResponseWriteError struct{ error } + InvalidURLError string ) func newResponseWriteError(cause error) *ierrors.Error { @@ -54,14 +54,18 @@ func newInvalidURLErrorf(status int, format string, args ...interface{}) error { func (e InvalidURLError) Error() string { return string(e) } -func newTooManyRequestsError() error { - return ierrors.Wrap( - TooManyRequestsError{}, - 1, - ierrors.WithStatusCode(http.StatusTooManyRequests), - ierrors.WithPublicMessage("Too many requests"), - ierrors.WithShouldReport(false), - ) +// newCantSaveError creates "resulting image not supported" error +func newCantSaveError(format imagetype.Type) error { + return ierrors.Wrap(newInvalidURLErrorf( + http.StatusUnprocessableEntity, + "Resulting image format is not supported: %s", format, + ), 1, ierrors.WithCategory(categoryPathParsing)) } -func (e TooManyRequestsError) Error() string { return "Too many requests" } +// newCantLoadError creates "source image not supported" error +func newCantLoadError(format imagetype.Type) error { + return ierrors.Wrap(newInvalidURLErrorf( + http.StatusUnprocessableEntity, + "Source image format is not supported: %s", format, + ), 1, ierrors.WithCategory(categoryProcessing)) +} diff --git a/handlers/processing/handler.go b/handlers/processing/handler.go new file mode 100644 index 00000000..95d37a0d --- /dev/null +++ b/handlers/processing/handler.go @@ -0,0 +1,143 @@ +package processing + +import ( + "context" + "net/http" + "net/url" + + "github.com/imgproxy/imgproxy/v3/auximageprovider" + "github.com/imgproxy/imgproxy/v3/errorreport" + "github.com/imgproxy/imgproxy/v3/handlers/stream" + "github.com/imgproxy/imgproxy/v3/headerwriter" + "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/monitoring" + "github.com/imgproxy/imgproxy/v3/monitoring/stats" + "github.com/imgproxy/imgproxy/v3/options" + "github.com/imgproxy/imgproxy/v3/security" + "github.com/imgproxy/imgproxy/v3/semaphores" +) + +// Handler handles image processing requests +type Handler struct { + hw *headerwriter.Writer // Configured HeaderWriter instance + stream *stream.Handler // Stream handler for raw image streaming + config *Config // Handler configuration + semaphores *semaphores.Semaphores + fallbackImage auximageprovider.Provider +} + +// New creates new handler object +func New( + stream *stream.Handler, + hw *headerwriter.Writer, + semaphores *semaphores.Semaphores, + fi auximageprovider.Provider, + config *Config, +) (*Handler, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + return &Handler{ + hw: hw, + config: config, + stream: stream, + semaphores: semaphores, + fallbackImage: fi, + }, nil +} + +// Execute handles the image processing request +func (h *Handler) Execute( + reqID string, + rw http.ResponseWriter, + imageRequest *http.Request, +) error { + // Increment the number of requests in progress + stats.IncRequestsInProgress() + defer stats.DecRequestsInProgress() + + ctx := imageRequest.Context() + + // Verify URL signature and extract image url and processing options + imageURL, po, mm, err := h.newRequest(ctx, imageRequest) + if err != nil { + return err + } + + // if processing options indicate raw image streaming, stream it and return + if po.Raw { + return h.stream.Execute(ctx, imageRequest, imageURL, reqID, po, rw) + } + + req := &request{ + handler: h, + imageRequest: imageRequest, + reqID: reqID, + rw: rw, + config: h.config, + po: po, + imageURL: imageURL, + monitoringMeta: mm, + semaphores: h.semaphores, + hwr: h.hw.NewRequest(), + } + + return req.execute(ctx) +} + +// newRequest extracts image url and processing options from request URL and verifies them +func (h *Handler) newRequest( + ctx context.Context, + imageRequest *http.Request, +) (string, *options.ProcessingOptions, monitoring.Meta, error) { + // let's extract signature and valid request path from a request + path, signature, err := splitPathSignature(imageRequest, h.config) + if err != nil { + return "", nil, nil, err + } + + // verify the signature (if any) + if err = security.VerifySignature(signature, path); err != nil { + return "", nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity)) + } + + // parse image url and processing options + po, imageURL, err := options.ParsePath(path, imageRequest.Header) + if err != nil { + return "", nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryPathParsing)) + } + + // get image origin and create monitoring meta object + imageOrigin := imageOrigin(imageURL) + + mm := monitoring.Meta{ + monitoring.MetaSourceImageURL: imageURL, + monitoring.MetaSourceImageOrigin: imageOrigin, + monitoring.MetaProcessingOptions: po.Diff().Flatten(), + } + + // set error reporting and monitoring context + errorreport.SetMetadata(imageRequest, "Source Image URL", imageURL) + errorreport.SetMetadata(imageRequest, "Source Image Origin", imageOrigin) + errorreport.SetMetadata(imageRequest, "Processing Options", po) + + monitoring.SetMetadata(ctx, mm) + + // verify that image URL came from the valid source + err = security.VerifySourceURL(imageURL) + if err != nil { + return "", nil, mm, ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity)) + } + + return imageURL, po, mm, nil +} + +// imageOrigin extracts image origin from URL +func imageOrigin(imageURL string) string { + if u, uerr := url.Parse(imageURL); uerr == nil { + return u.Scheme + "://" + u.Host + } + + return "" +} diff --git a/handlers/processing/path.go b/handlers/processing/path.go new file mode 100644 index 00000000..be6c5d64 --- /dev/null +++ b/handlers/processing/path.go @@ -0,0 +1,56 @@ +package processing + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/imgproxy/imgproxy/v3/ierrors" +) + +// fixPathRe is used in path re-denormalization +var fixPathRe = regexp.MustCompile(`/plain/(\S+)\:/([^/])`) + +// splitPathSignature splits signature and path components from the request URI +func splitPathSignature(r *http.Request, config *Config) (string, string, error) { + uri := r.RequestURI + + // cut query params + uri, _, _ = strings.Cut(uri, "?") + + // cut path prefix + if len(config.PathPrefix) > 0 { + uri = strings.TrimPrefix(uri, config.PathPrefix) + } + + // cut leading slash + uri = strings.TrimPrefix(uri, "/") + + signature, path, _ := strings.Cut(uri, "/") + if len(signature) == 0 || len(path) == 0 { + return "", "", ierrors.Wrap( + newInvalidURLErrorf(http.StatusNotFound, "Invalid path: %s", path), 0, + ierrors.WithCategory(categoryPathParsing), + ) + } + + // restore broken slashes in the path + path = redenormalizePath(path) + + return path, signature, nil +} + +// redenormalizePath undoes path normalization done by some browsers and revers proxies +func redenormalizePath(path string) string { + for _, match := range fixPathRe.FindAllStringSubmatch(path, -1) { + repl := fmt.Sprintf("/plain/%s://", match[1]) + if match[1] == "local" { + repl += "/" + } + repl += match[2] + path = strings.Replace(path, match[0], repl, 1) + } + + return path +} diff --git a/handlers/processing/path_test.go b/handlers/processing/path_test.go new file mode 100644 index 00000000..12e19f60 --- /dev/null +++ b/handlers/processing/path_test.go @@ -0,0 +1,180 @@ +package processing + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/stretchr/testify/suite" +) + +type PathTestSuite struct { + suite.Suite +} + +func TestPathTestSuite(t *testing.T) { + suite.Run(t, new(PathTestSuite)) +} + +func (s *PathTestSuite) createRequest(path string) *http.Request { + return httptest.NewRequest("GET", path, nil) +} + +func (s *PathTestSuite) TestParsePath() { + testCases := []struct { + name string + pathPrefix string + requestPath string + expectedPath string + expectedSig string + expectedError bool + }{ + { + name: "BasicPath", + requestPath: "/dummy_signature/rs:fill:300:200/plain/http://example.com/image.jpg", + expectedPath: "rs:fill:300:200/plain/http://example.com/image.jpg", + expectedSig: "dummy_signature", + expectedError: false, + }, + { + name: "PathWithQueryParams", + requestPath: "/dummy_signature/rs:fill:300:200/plain/http://example.com/image.jpg?param1=value1¶m2=value2", + expectedPath: "rs:fill:300:200/plain/http://example.com/image.jpg", + expectedSig: "dummy_signature", + expectedError: false, + }, + { + name: "PathWithPrefix", + pathPrefix: "/imgproxy", + requestPath: "/imgproxy/dummy_signature/rs:fill:300:200/plain/http://example.com/image.jpg", + expectedPath: "rs:fill:300:200/plain/http://example.com/image.jpg", + expectedSig: "dummy_signature", + expectedError: false, + }, + { + name: "PathWithRedenormalization", + requestPath: "/dummy_signature/rs:fill:300:200/plain/https:/example.com/path/to/image.jpg", + expectedPath: "rs:fill:300:200/plain/https://example.com/path/to/image.jpg", + expectedSig: "dummy_signature", + expectedError: false, + }, + { + name: "NoSignatureSeparator", + requestPath: "/invalid_path_without_slash", + expectedPath: "", + expectedSig: "", + expectedError: true, + }, + { + name: "EmptyPath", + requestPath: "/", + expectedPath: "", + expectedSig: "", + expectedError: true, + }, + { + name: "OnlySignature", + requestPath: "/signature_only", + expectedPath: "", + expectedSig: "", + expectedError: true, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + config := &Config{ + PathPrefix: tc.pathPrefix, + } + + req := s.createRequest(tc.requestPath) + path, signature, err := splitPathSignature(req, config) + + if tc.expectedError { + var ierr *ierrors.Error + + s.Require().Error(err) + s.Require().ErrorAs(err, &ierr) + s.Require().Equal(categoryPathParsing, ierr.Category()) + + return + } + + s.Require().NoError(err) + s.Require().Equal(tc.expectedPath, path) + s.Require().Equal(tc.expectedSig, signature) + }) + } +} + +func (s *PathTestSuite) TestRedenormalizePathHTTPProtocol() { + testCases := []struct { + name string + input string + expected string + }{ + { + name: "HTTP", + input: "/plain/http:/example.com/image.jpg", + expected: "/plain/http://example.com/image.jpg", + }, + { + name: "HTTPS", + input: "/plain/https:/example.com/image.jpg", + expected: "/plain/https://example.com/image.jpg", + }, + { + name: "Local", + input: "/plain/local:/image.jpg", + expected: "/plain/local:///image.jpg", + }, + { + name: "NormalizedPath", + input: "/plain/http://example.com/image.jpg", + expected: "/plain/http://example.com/image.jpg", + }, + { + name: "ProtocolMissing", + input: "/rs:fill:300:200/plain/example.com/image.jpg", + expected: "/rs:fill:300:200/plain/example.com/image.jpg", + }, + { + name: "EmptyString", + input: "", + expected: "", + }, + { + name: "SingleSlash", + input: "/", + expected: "/", + }, + { + name: "NoPlainPrefix", + input: "/http:/example.com/image.jpg", + expected: "/http:/example.com/image.jpg", + }, + { + name: "NoProtocol", + input: "/plain/example.com/image.jpg", + expected: "/plain/example.com/image.jpg", + }, + { + name: "EndsWithProtocol", + input: "/plain/http:", + expected: "/plain/http:", + }, + { + name: "OnlyProtocol", + input: "/plain/http:/test", + expected: "/plain/http://test", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + result := redenormalizePath(tc.input) + s.Equal(tc.expected, result) + }) + } +} diff --git a/handlers/processing/request.go b/handlers/processing/request.go new file mode 100644 index 00000000..79e43052 --- /dev/null +++ b/handlers/processing/request.go @@ -0,0 +1,140 @@ +package processing + +import ( + "context" + "errors" + "net/http" + + "github.com/imgproxy/imgproxy/v3/headerwriter" + "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/imagefetcher" + "github.com/imgproxy/imgproxy/v3/imagetype" + "github.com/imgproxy/imgproxy/v3/monitoring" + "github.com/imgproxy/imgproxy/v3/monitoring/stats" + "github.com/imgproxy/imgproxy/v3/options" + "github.com/imgproxy/imgproxy/v3/semaphores" + "github.com/imgproxy/imgproxy/v3/server" + "github.com/imgproxy/imgproxy/v3/vips" +) + +// request holds the parameters and state for a single request request +type request struct { + handler *Handler + imageRequest *http.Request + reqID string + rw http.ResponseWriter + config *Config + po *options.ProcessingOptions + imageURL string + monitoringMeta monitoring.Meta + semaphores *semaphores.Semaphores + hwr *headerwriter.Request +} + +// execute handles the actual processing logic +func (r *request) execute(ctx context.Context) error { + // Check if we can save the resulting image + canSave := vips.SupportsSave(r.po.Format) || + r.po.Format == imagetype.Unknown || + r.po.Format == imagetype.SVG + + if !canSave { + return newCantSaveError(r.po.Format) + } + + // Acquire queue semaphore (if enabled) + releaseQueueSem, err := r.semaphores.AcquireQueue() + if err != nil { + return err + } + defer releaseQueueSem() + + // Acquire processing semaphore + releaseProcessingSem, err := r.acquireProcessingSem(ctx) + if err != nil { + return err + } + defer releaseProcessingSem() + + // Deal with processing image counter + stats.IncImagesInProgress() + defer stats.DecImagesInProgress() + + // Response status code is OK by default + statusCode := http.StatusOK + + // Request headers + imgRequestHeaders := r.makeImageRequestHeaders() + + // create download options + do := r.makeDownloadOptions(ctx, imgRequestHeaders) + + // Fetch image actual + originData, originHeaders, err := r.fetchImage(ctx, do) + if err == nil { + defer originData.Close() // if any originData has been opened, we need to close it + } + + // Check that image detection didn't take too long + if terr := server.CheckTimeout(ctx); terr != nil { + return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } + + // Respond with NotModified if image was not modified + var nmErr imagefetcher.NotModifiedError + + if errors.As(err, &nmErr) { + r.hwr.SetOriginHeaders(nmErr.Headers()) + + return r.respondWithNotModified() + } + + // Prepare to write image response headers + r.hwr.SetOriginHeaders(originHeaders) + + // If error is not related to NotModified, respond with fallback image and replace image data + if err != nil { + originData, statusCode, err = r.handleDownloadError(ctx, err) + if err != nil { + return err + } + } + + // Check if image supports load from origin format + if !vips.SupportsLoad(originData.Format()) { + return newCantLoadError(originData.Format()) + } + + // Actually process the image + result, err := r.processImage(ctx, originData) + + // Let's close resulting image data only if it differs from the source image data + if result != nil && result.OutData != nil && result.OutData != originData { + defer result.OutData.Close() + } + + // First, check if the processing error wasn't caused by an image data error + if derr := originData.Error(); derr != nil { + return ierrors.Wrap(derr, 0, ierrors.WithCategory(categoryDownload)) + } + + // If it wasn't, than it was a processing error + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryProcessing)) + } + + // Write debug headers. It seems unlogical to move they to headerwriter since they're + // not used anywhere else. + err = r.writeDebugHeaders(result, originData) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) + } + + // Responde with actual image + err = r.respondWithImage(statusCode, result.OutData) + if err != nil { + return err + } + + return nil +} diff --git a/handlers/processing/request_methods.go b/handlers/processing/request_methods.go new file mode 100644 index 00000000..219ad41c --- /dev/null +++ b/handlers/processing/request_methods.go @@ -0,0 +1,278 @@ +package processing + +import ( + "context" + "io" + "net/http" + "strconv" + + "github.com/imgproxy/imgproxy/v3/cookies" + "github.com/imgproxy/imgproxy/v3/errorreport" + "github.com/imgproxy/imgproxy/v3/httpheaders" + "github.com/imgproxy/imgproxy/v3/ierrors" + "github.com/imgproxy/imgproxy/v3/imagedata" + "github.com/imgproxy/imgproxy/v3/monitoring" + "github.com/imgproxy/imgproxy/v3/options" + "github.com/imgproxy/imgproxy/v3/processing" + "github.com/imgproxy/imgproxy/v3/server" + log "github.com/sirupsen/logrus" +) + +// makeImageRequestHeaders creates headers for the image request +func (r *request) makeImageRequestHeaders() http.Header { + h := make(http.Header) + + // If ETag is enabled, we forward If-None-Match header + if r.config.ETagEnabled { + h.Set(httpheaders.IfNoneMatch, r.imageRequest.Header.Get(httpheaders.IfNoneMatch)) + } + + // If LastModified is enabled, we forward If-Modified-Since header + if r.config.LastModifiedEnabled { + h.Set(httpheaders.IfModifiedSince, r.imageRequest.Header.Get(httpheaders.IfModifiedSince)) + } + + return h +} + +// acquireProcessingSem acquires the processing semaphore +func (r *request) acquireProcessingSem(ctx context.Context) (context.CancelFunc, error) { + defer monitoring.StartQueueSegment(ctx)() + + fn, err := r.semaphores.AcquireProcessing(ctx) + if err != nil { + // We don't actually need to check timeout here, + // but it's an easy way to check if this is an actual timeout + // or the request was canceled + if terr := server.CheckTimeout(ctx); terr != nil { + return nil, ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) + } + + // We should never reach this line as err could be only ctx.Err() + // and we've already checked for it. But beter safe than sorry + return nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryQueue)) + } + + return fn, nil +} + +// makeDownloadOptions creates a new default download options +func (r *request) makeDownloadOptions(ctx context.Context, h http.Header) imagedata.DownloadOptions { + downloadFinished := monitoring.StartDownloadingSegment(ctx, r.monitoringMeta.Filter( + monitoring.MetaSourceImageURL, + monitoring.MetaSourceImageOrigin, + )) + + return imagedata.DownloadOptions{ + Header: h, + MaxSrcFileSize: r.po.SecurityOptions.MaxSrcFileSize, + DownloadFinished: downloadFinished, + } +} + +// fetchImage downloads the source image asynchronously +func (r *request) fetchImage(ctx context.Context, do imagedata.DownloadOptions) (imagedata.ImageData, http.Header, error) { + var err error + + if r.config.CookiePassthrough { + do.CookieJar, err = cookies.JarFromRequest(r.imageRequest) + if err != nil { + return nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload)) + } + } + + return imagedata.DownloadAsync(ctx, r.imageURL, "source image", do) +} + +// handleDownloadError replaces the image data with fallback image if needed +func (r *request) handleDownloadError( + ctx context.Context, + originalErr error, +) (imagedata.ImageData, int, error) { + err := r.wrapDownloadingErr(originalErr) + + // If there is no fallback image configured, just return the error + data, headers := r.getFallbackImage(ctx, r.po) + if data == nil { + return nil, 0, err + } + + // Just send error + monitoring.SendError(ctx, categoryDownload, err) + + // We didn't return, so we have to report error + if err.ShouldReport() { + errorreport.Report(err, r.imageRequest) + } + + log. + WithField("request_id", r.reqID). + Warningf("Could not load image %s. Using fallback image. %s", r.imageURL, err.Error()) + + var statusCode int + + // Set status code if needed + if r.config.FallbackImageHTTPCode > 0 { + statusCode = r.config.FallbackImageHTTPCode + } else { + statusCode = err.StatusCode() + } + + // Fallback image should have exact FallbackImageTTL lifetime + headers.Del(httpheaders.Expires) + headers.Del(httpheaders.LastModified) + + r.hwr.SetOriginHeaders(headers) + r.hwr.SetIsFallbackImage() + + return data, statusCode, nil +} + +// getFallbackImage returns fallback image if any +func (r *request) getFallbackImage( + ctx context.Context, + po *options.ProcessingOptions, +) (imagedata.ImageData, http.Header) { + if r.handler.fallbackImage == nil { + return nil, nil + } + + data, h, err := r.handler.fallbackImage.Get(ctx, po) + if err != nil { + log.Warning(err.Error()) + + if ierr := r.wrapDownloadingErr(err); ierr.ShouldReport() { + errorreport.Report(ierr, r.imageRequest) + } + + return nil, nil + } + + return data, h +} + +// processImage calls actual image processing +func (r *request) processImage(ctx context.Context, originData imagedata.ImageData) (*processing.Result, error) { + defer monitoring.StartProcessingSegment(ctx, r.monitoringMeta.Filter(monitoring.MetaProcessingOptions))() + return processing.ProcessImage(ctx, originData, r.po) +} + +// writeDebugHeaders writes debug headers (X-Origin-*, X-Result-*) to the response +func (r *request) writeDebugHeaders(result *processing.Result, originData imagedata.ImageData) error { + if !r.config.EnableDebugHeaders { + return nil + } + + if result != nil { + r.rw.Header().Set(httpheaders.XOriginWidth, strconv.Itoa(result.OriginWidth)) + r.rw.Header().Set(httpheaders.XOriginHeight, strconv.Itoa(result.OriginHeight)) + r.rw.Header().Set(httpheaders.XResultWidth, strconv.Itoa(result.ResultWidth)) + r.rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight)) + } + + // Try to read origin image size + size, err := originData.Size() + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) + } + + r.rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size)) + + return nil +} + +// respondWithNotModified writes not-modified response +func (r *request) respondWithNotModified() error { + r.hwr.SetExpires(r.po.Expires) + r.hwr.SetVary() + + if r.config.LastModifiedEnabled { + r.hwr.Passthrough(httpheaders.LastModified) + } + + if r.config.ETagEnabled { + r.hwr.Passthrough(httpheaders.Etag) + } + + r.hwr.Write(r.rw) + + r.rw.WriteHeader(http.StatusNotModified) + + server.LogResponse( + r.reqID, r.imageRequest, http.StatusNotModified, nil, + log.Fields{ + "image_url": r.imageURL, + "processing_options": r.po, + }, + ) + + return nil +} + +func (r *request) respondWithImage(statusCode int, resultData imagedata.ImageData) error { + // We read the size of the image data here, so we can set Content-Length header. + // This indireclty ensures that the image data is fully read from the source, no + // errors happened. + resultSize, err := resultData.Size() + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) + } + + r.hwr.SetContentType(resultData.Format().Mime()) + r.hwr.SetContentLength(resultSize) + r.hwr.SetContentDisposition( + r.imageURL, + r.po.Filename, + resultData.Format().Ext(), + "", + r.po.ReturnAttachment, + ) + r.hwr.SetExpires(r.po.Expires) + r.hwr.SetVary() + r.hwr.SetCanonical(r.imageURL) + + if r.config.LastModifiedEnabled { + r.hwr.Passthrough(httpheaders.LastModified) + } + + if r.config.ETagEnabled { + r.hwr.Passthrough(httpheaders.Etag) + } + + r.hwr.Write(r.rw) + + r.rw.WriteHeader(statusCode) + + _, err = io.Copy(r.rw, resultData.Reader()) + + var ierr *ierrors.Error + if err != nil { + ierr = newResponseWriteError(err) + + if r.config.ReportIOErrors { + return ierrors.Wrap(ierr, 0, ierrors.WithCategory(categoryIO), ierrors.WithShouldReport(true)) + } + } + + server.LogResponse( + r.reqID, r.imageRequest, statusCode, ierr, + log.Fields{ + "image_url": r.imageURL, + "processing_options": r.po, + }, + ) + + return nil +} + +// wrapDownloadingErr wraps original error to download error +func (r *request) wrapDownloadingErr(originalErr error) *ierrors.Error { + err := ierrors.Wrap(originalErr, 0, ierrors.WithCategory(categoryDownload)) + + // we report this error only if enabled + if r.config.ReportDownloadingErrors { + err = ierrors.Wrap(err, 0, ierrors.WithShouldReport(true)) + } + + return err +} diff --git a/handlers/stream/config.go b/handlers/stream/config.go index 510aeafe..4ac0be5f 100644 --- a/handlers/stream/config.go +++ b/handlers/stream/config.go @@ -39,7 +39,7 @@ func NewDefaultConfig() *Config { } // LoadFromEnv loads config variables from environment -func (c *Config) LoadFromEnv() (*Config, error) { +func LoadFromEnv(c *Config) (*Config, error) { c.CookiePassthrough = config.CookiePassthrough return c, nil } diff --git a/handlers/stream/handler.go b/handlers/stream/handler.go index 22c33c21..4d1f955c 100644 --- a/handlers/stream/handler.go +++ b/handlers/stream/handler.go @@ -48,6 +48,7 @@ type request struct { reqID string po *options.ProcessingOptions rw http.ResponseWriter + hw *headerwriter.Request } // New creates new handler object @@ -79,6 +80,7 @@ func (s *Handler) Execute( reqID: reqID, po: po, rw: rw, + hw: s.hw.NewRequest(), } return stream.execute(ctx) @@ -116,18 +118,17 @@ func (s *request) execute(ctx context.Context) error { } // Output streaming response headers - hw := s.handler.hw.NewRequest(res.Header, s.imageURL) - - hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was - hw.SetContentLength(int(res.ContentLength)) - hw.SetCanonical() - hw.SetExpires(s.po.Expires) + s.hw.SetOriginHeaders(res.Header) + s.hw.Passthrough(s.handler.config.PassthroughResponseHeaders...) // NOTE: priority? This is lowest as it was + s.hw.SetContentLength(int(res.ContentLength)) + s.hw.SetCanonical(s.imageURL) + s.hw.SetExpires(s.po.Expires) // Set the Content-Disposition header - s.setContentDisposition(r.URL().Path, res, hw) + s.setContentDisposition(r.URL().Path, res, s.hw) // Write headers from writer - hw.Write(s.rw) + s.hw.Write(s.rw) // Copy the status code from the original response s.rw.WriteHeader(res.StatusCode) diff --git a/headerwriter/config.go b/headerwriter/config.go index 731fdc1b..c748bb90 100644 --- a/headerwriter/config.go +++ b/headerwriter/config.go @@ -29,7 +29,7 @@ func NewDefaultConfig() *Config { } // LoadFromEnv overrides configuration variables from environment -func (c *Config) LoadFromEnv() (*Config, error) { +func LoadFromEnv(c *Config) (*Config, error) { c.SetCanonicalHeader = config.SetCanonicalHeader c.DefaultTTL = config.TTL c.FallbackImageTTL = config.FallbackImageTTL diff --git a/headerwriter/writer.go b/headerwriter/writer.go index c4f6b5b2..62a27d6f 100644 --- a/headerwriter/writer.go +++ b/headerwriter/writer.go @@ -19,11 +19,10 @@ type Writer struct { // Request is a private struct that builds HTTP response headers for a specific request. type Request struct { - writer *Writer - originalResponseHeaders http.Header // Original response headers - result http.Header // Headers to be written to the response - maxAge int // Current max age for Cache-Control header - url string // URL of the request, used for canonical header + writer *Writer + originHeaders http.Header // Original response headers + result http.Header // Headers to be written to the response + maxAge int // Current max age for Cache-Control header } // New creates a new header writer factory with the provided config. @@ -51,16 +50,20 @@ func New(config *Config) (*Writer, error) { } // NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL. -func (w *Writer) NewRequest(originalResponseHeaders http.Header, url string) *Request { +func (w *Writer) NewRequest() *Request { return &Request{ - writer: w, - originalResponseHeaders: originalResponseHeaders, - url: url, - result: make(http.Header), - maxAge: -1, + writer: w, + result: make(http.Header), + maxAge: -1, + originHeaders: make(http.Header), } } +// SetOriginHeaders sets the origin headers for the request. +func (r *Request) SetOriginHeaders(h http.Header) { + r.originHeaders = h +} + // SetIsFallbackImage sets the Fallback-Image header to // indicate that the fallback image was used. func (r *Request) SetIsFallbackImage() { @@ -114,7 +117,7 @@ func (r *Request) SetContentDisposition(originURL, filename, ext, contentType st // Passthrough copies specified headers from the original response headers to the response headers. func (r *Request) Passthrough(only ...string) { - httpheaders.Copy(r.originalResponseHeaders, r.result, only) + httpheaders.Copy(r.originHeaders, r.result, only) } // CopyFrom copies specified headers from the headers object. Please note that @@ -139,13 +142,13 @@ func (r *Request) SetContentType(mime string) { // writeCanonical sets the Link header with the canonical URL. // It is mandatory for any response if enabled in the configuration. -func (r *Request) SetCanonical() { +func (r *Request) SetCanonical(url string) { if !r.writer.config.SetCanonicalHeader { return } - if strings.HasPrefix(r.url, "https://") || strings.HasPrefix(r.url, "http://") { - value := fmt.Sprintf(`<%s>; rel="canonical"`, r.url) + if strings.HasPrefix(url, "https://") || strings.HasPrefix(url, "http://") { + value := fmt.Sprintf(`<%s>; rel="canonical"`, url) r.result.Set(httpheaders.Link, value) } } @@ -172,12 +175,12 @@ func (r *Request) setCacheControlPassthrough() bool { return false } - if val := r.originalResponseHeaders.Get(httpheaders.CacheControl); val != "" { + if val := r.originHeaders.Get(httpheaders.CacheControl); val != "" { r.result.Set(httpheaders.CacheControl, val) return true } - if val := r.originalResponseHeaders.Get(httpheaders.Expires); val != "" { + if val := r.originHeaders.Get(httpheaders.Expires); val != "" { if t, err := time.Parse(http.TimeFormat, val); err == nil { maxAge := max(0, int(time.Until(t).Seconds())) return r.setCacheControl(maxAge) diff --git a/headerwriter/writer_test.go b/headerwriter/writer_test.go index 8bac68c3..9893d2ce 100644 --- a/headerwriter/writer_test.go +++ b/headerwriter/writer_test.go @@ -19,7 +19,6 @@ type HeaderWriterSuite struct { type writerTestCase struct { name string - url string req http.Header res http.Header config Config @@ -94,7 +93,6 @@ func (s *HeaderWriterSuite) TestHeaderCases() { { name: "Canonical_ValidURL", req: http.Header{}, - url: "https://example.com/image.jpg", res: http.Header{ httpheaders.Link: []string{"; rel=\"canonical\""}, httpheaders.CacheControl: []string{"max-age=3600, public"}, @@ -105,12 +103,11 @@ func (s *HeaderWriterSuite) TestHeaderCases() { DefaultTTL: 3600, }, fn: func(w *Request) { - w.SetCanonical() + w.SetCanonical("https://example.com/image.jpg") }, }, { name: "Canonical_InvalidURL", - url: "ftp://example.com/image.jpg", req: http.Header{}, res: http.Header{ httpheaders.CacheControl: []string{"max-age=3600, public"}, @@ -124,7 +121,6 @@ func (s *HeaderWriterSuite) TestHeaderCases() { { name: "WriteCanonical_Disabled", req: http.Header{}, - url: "https://example.com/image.jpg", res: http.Header{ httpheaders.CacheControl: []string{"max-age=3600, public"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, @@ -134,7 +130,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { DefaultTTL: 3600, }, fn: func(w *Request) { - w.SetCanonical() + w.SetCanonical("https://example.com/image.jpg") }, }, { @@ -305,7 +301,8 @@ func (s *HeaderWriterSuite) TestHeaderCases() { factory, err := New(&tc.config) s.Require().NoError(err) - writer := factory.NewRequest(tc.req, tc.url) + writer := factory.NewRequest() + writer.SetOriginHeaders(tc.req) if tc.fn != nil { tc.fn(writer) diff --git a/imagedata/download.go b/imagedata/download.go index f61c8d51..528cd7ff 100644 --- a/imagedata/download.go +++ b/imagedata/download.go @@ -40,7 +40,7 @@ func initDownloading() error { return err } - c, err := imagefetcher.NewDefaultConfig().LoadFromEnv() + c, err := imagefetcher.LoadFromEnv(imagefetcher.NewDefaultConfig()) if err != nil { return ierrors.Wrap(err, 0, ierrors.WithPrefix("configuration error")) } diff --git a/imagedata/image_data.go b/imagedata/image_data.go index b40a1f5e..4a663c2c 100644 --- a/imagedata/image_data.go +++ b/imagedata/image_data.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "io" - "net/http" "sync" "github.com/imgproxy/imgproxy/v3/asyncbuffer" @@ -14,9 +13,7 @@ import ( ) var ( - Watermark ImageData - FallbackImage ImageData - FallbackImageHeaders http.Header // Headers for the fallback image + Watermark ImageData ) // ImageData represents the data of an image that can be read from a source. @@ -139,10 +136,6 @@ func Init() error { return err } - if err := loadFallbackImage(); err != nil { - return err - } - return nil } @@ -178,30 +171,3 @@ func loadWatermark() error { return nil } - -func loadFallbackImage() (err error) { - switch { - case len(config.FallbackImageData) > 0: - FallbackImage, err = NewFromBase64(config.FallbackImageData) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't load fallback image from Base64")) - } - - case len(config.FallbackImagePath) > 0: - FallbackImage, err = NewFromPath(config.FallbackImagePath) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't read fallback image from file")) - } - - case len(config.FallbackImageURL) > 0: - FallbackImage, FallbackImageHeaders, err = DownloadSync(context.Background(), config.FallbackImageURL, "fallback image", DefaultDownloadOptions()) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't download from URL")) - } - - default: - FallbackImage = nil - } - - return err -} diff --git a/imagefetcher/config.go b/imagefetcher/config.go index 7939c208..c5b1fad0 100644 --- a/imagefetcher/config.go +++ b/imagefetcher/config.go @@ -28,7 +28,7 @@ func NewDefaultConfig() *Config { } // LoadFromEnv loads config variables from env -func (c *Config) LoadFromEnv() (*Config, error) { +func LoadFromEnv(c *Config) (*Config, error) { c.UserAgent = config.UserAgent c.DownloadTimeout = time.Duration(config.DownloadTimeout) * time.Second c.MaxRedirects = config.MaxRedirects diff --git a/main.go b/main.go index b2c04a3f..e1e476f1 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "net/http" "os" "os/signal" "syscall" @@ -12,11 +13,16 @@ import ( log "github.com/sirupsen/logrus" "go.uber.org/automaxprocs/maxprocs" + "github.com/imgproxy/imgproxy/v3/auximageprovider" "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/config/loadenv" "github.com/imgproxy/imgproxy/v3/errorreport" "github.com/imgproxy/imgproxy/v3/gliblog" "github.com/imgproxy/imgproxy/v3/handlers" + processingHandler "github.com/imgproxy/imgproxy/v3/handlers/processing" + "github.com/imgproxy/imgproxy/v3/handlers/stream" + "github.com/imgproxy/imgproxy/v3/headerwriter" + "github.com/imgproxy/imgproxy/v3/ierrors" "github.com/imgproxy/imgproxy/v3/imagedata" "github.com/imgproxy/imgproxy/v3/logger" "github.com/imgproxy/imgproxy/v3/memory" @@ -24,16 +30,78 @@ import ( "github.com/imgproxy/imgproxy/v3/monitoring/prometheus" "github.com/imgproxy/imgproxy/v3/options" "github.com/imgproxy/imgproxy/v3/processing" + "github.com/imgproxy/imgproxy/v3/semaphores" "github.com/imgproxy/imgproxy/v3/server" "github.com/imgproxy/imgproxy/v3/version" "github.com/imgproxy/imgproxy/v3/vips" ) const ( - faviconPath = "/favicon.ico" - healthPath = "/health" + faviconPath = "/favicon.ico" + healthPath = "/health" + categoryConfig = "(tmp)config" // NOTE: temporary category for reporting configration errors ) +func callHandleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) error { + // NOTE: This is temporary, will be moved level up at once + hwc, err := headerwriter.LoadFromEnv(headerwriter.NewDefaultConfig()) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + hw, err := headerwriter.New(hwc) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + sc, err := stream.LoadFromEnv(stream.NewDefaultConfig()) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + stream, err := stream.New(sc, hw, imagedata.Fetcher) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + phc, err := processingHandler.LoadFromEnv(processingHandler.NewDefaultConfig()) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + semc, err := semaphores.LoadFromEnv(semaphores.NewDefaultConfig()) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + semaphores, err := semaphores.New(semc) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + fic := auximageprovider.NewDefaultStaticConfig() + fic, err = auximageprovider.LoadFallbackStaticConfigFromEnv(fic) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + fi, err := auximageprovider.NewStaticProvider( + r.Context(), + fic, + "fallback image", + ) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + h, err := processingHandler.New(stream, hw, semaphores, fi, phc) + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) + } + + return h.Execute(reqID, rw, r) +} + func buildRouter(r *server.Router) *server.Router { r.GET("/", handlers.LandingHandler) r.GET("", handlers.LandingHandler) @@ -80,8 +148,6 @@ func initialize() error { return err } - initProcessingHandler() - errorreport.Init() if err := vips.Init(); err != nil { @@ -137,7 +203,7 @@ func run(ctx context.Context) error { return err } - cfg, err := server.NewDefaultConfig().LoadFromEnv() + cfg, err := server.LoadFromEnv(server.NewDefaultConfig()) if err != nil { return err } diff --git a/processing_handler.go b/processing_handler.go deleted file mode 100644 index 643d1aec..00000000 --- a/processing_handler.go +++ /dev/null @@ -1,425 +0,0 @@ -package main - -import ( - "errors" - "io" - "net/http" - "net/url" - "strconv" - "strings" - - log "github.com/sirupsen/logrus" - "golang.org/x/sync/semaphore" - - "github.com/imgproxy/imgproxy/v3/config" - "github.com/imgproxy/imgproxy/v3/cookies" - "github.com/imgproxy/imgproxy/v3/errorreport" - "github.com/imgproxy/imgproxy/v3/handlers/stream" - "github.com/imgproxy/imgproxy/v3/headerwriter" - "github.com/imgproxy/imgproxy/v3/httpheaders" - "github.com/imgproxy/imgproxy/v3/ierrors" - "github.com/imgproxy/imgproxy/v3/imagedata" - "github.com/imgproxy/imgproxy/v3/imagefetcher" - "github.com/imgproxy/imgproxy/v3/imagetype" - "github.com/imgproxy/imgproxy/v3/monitoring" - "github.com/imgproxy/imgproxy/v3/monitoring/stats" - "github.com/imgproxy/imgproxy/v3/options" - "github.com/imgproxy/imgproxy/v3/processing" - "github.com/imgproxy/imgproxy/v3/security" - "github.com/imgproxy/imgproxy/v3/server" - "github.com/imgproxy/imgproxy/v3/vips" -) - -var ( - queueSem *semaphore.Weighted - processingSem *semaphore.Weighted -) - -func initProcessingHandler() { - if config.RequestsQueueSize > 0 { - queueSem = semaphore.NewWeighted(int64(config.RequestsQueueSize + config.Workers)) - } - - processingSem = semaphore.NewWeighted(int64(config.Workers)) -} - -// writeDebugHeaders writes debug headers (X-Origin-*, X-Result-*) to the response -func writeDebugHeaders(rw http.ResponseWriter, result *processing.Result, originData imagedata.ImageData) error { - if !config.EnableDebugHeaders { - return nil - } - - if result != nil { - rw.Header().Set(httpheaders.XOriginWidth, strconv.Itoa(result.OriginWidth)) - rw.Header().Set(httpheaders.XOriginHeight, strconv.Itoa(result.OriginHeight)) - rw.Header().Set(httpheaders.XResultWidth, strconv.Itoa(result.ResultWidth)) - rw.Header().Set(httpheaders.XResultHeight, strconv.Itoa(result.ResultHeight)) - } - - // Try to read origin image size - size, err := originData.Size() - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) - } - - rw.Header().Set(httpheaders.XOriginContentLength, strconv.Itoa(size)) - - return nil -} - -func respondWithImage( - reqID string, - r *http.Request, - rw http.ResponseWriter, - statusCode int, - resultData imagedata.ImageData, - po *options.ProcessingOptions, - originURL string, - hw *headerwriter.Request, -) error { - // We read the size of the image data here, so we can set Content-Length header. - // This indireclty ensures that the image data is fully read from the source, no - // errors happened. - resultSize, err := resultData.Size() - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) - } - - hw.SetContentType(resultData.Format().Mime()) - hw.SetContentLength(resultSize) - hw.SetContentDisposition( - originURL, - po.Filename, - resultData.Format().Ext(), - "", - po.ReturnAttachment, - ) - hw.SetExpires(po.Expires) - hw.SetVary() - hw.SetCanonical() - - if config.LastModifiedEnabled { - hw.Passthrough(httpheaders.LastModified) - } - - if config.ETagEnabled { - hw.Passthrough(httpheaders.Etag) - } - - hw.Write(rw) - - rw.WriteHeader(statusCode) - - _, err = io.Copy(rw, resultData.Reader()) - - var ierr *ierrors.Error - if err != nil { - ierr = newResponseWriteError(err) - - if config.ReportIOErrors { - return ierrors.Wrap(ierr, 0, ierrors.WithCategory(categoryIO), ierrors.WithShouldReport(true)) - } - } - - server.LogResponse( - reqID, r, statusCode, ierr, - log.Fields{ - "image_url": originURL, - "processing_options": po, - }, - ) - - return nil -} - -func respondWithNotModified(reqID string, r *http.Request, rw http.ResponseWriter, po *options.ProcessingOptions, originURL string, hw *headerwriter.Request) { - hw.SetExpires(po.Expires) - hw.SetVary() - - if config.ETagEnabled { - hw.Passthrough(httpheaders.Etag) - } - - hw.Write(rw) - - rw.WriteHeader(http.StatusNotModified) - - server.LogResponse( - reqID, r, http.StatusNotModified, nil, - log.Fields{ - "image_url": originURL, - "processing_options": po, - }, - ) -} - -func callHandleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) error { - // NOTE: This is temporary, will be moved level up at once - hwc, err := headerwriter.NewDefaultConfig().LoadFromEnv() - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) - } - - hw, err := headerwriter.New(hwc) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) - } - - sc, err := stream.NewDefaultConfig().LoadFromEnv() - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) - } - - stream, err := stream.New(sc, hw, imagedata.Fetcher) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryConfig)) - } - - return handleProcessing(reqID, rw, r, hw, stream) -} - -func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request, hw *headerwriter.Writer, stream *stream.Handler) error { - stats.IncRequestsInProgress() - defer stats.DecRequestsInProgress() - - ctx := r.Context() - - path := r.RequestURI - if queryStart := strings.IndexByte(path, '?'); queryStart >= 0 { - path = path[:queryStart] - } - - if len(config.PathPrefix) > 0 { - path = strings.TrimPrefix(path, config.PathPrefix) - } - - path = strings.TrimPrefix(path, "/") - signature := "" - - if signatureEnd := strings.IndexByte(path, '/'); signatureEnd > 0 { - signature = path[:signatureEnd] - path = path[signatureEnd:] - } else { - return ierrors.Wrap( - newInvalidURLErrorf(http.StatusNotFound, "Invalid path: %s", path), 0, - ierrors.WithCategory(categoryPathParsing), - ) - } - - path = fixPath(path) - - if err := security.VerifySignature(signature, path); err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity)) - } - - po, imageURL, err := options.ParsePath(path, r.Header) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryPathParsing)) - } - - var imageOrigin any - if u, uerr := url.Parse(imageURL); uerr == nil { - imageOrigin = u.Scheme + "://" + u.Host - } - - errorreport.SetMetadata(r, "Source Image URL", imageURL) - errorreport.SetMetadata(r, "Source Image Origin", imageOrigin) - errorreport.SetMetadata(r, "Processing Options", po) - - monitoringMeta := monitoring.Meta{ - monitoring.MetaSourceImageURL: imageURL, - monitoring.MetaSourceImageOrigin: imageOrigin, - monitoring.MetaProcessingOptions: po.Diff().Flatten(), - } - - monitoring.SetMetadata(ctx, monitoringMeta) - - err = security.VerifySourceURL(imageURL) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categorySecurity)) - } - - if po.Raw { - return stream.Execute(ctx, r, imageURL, reqID, po, rw) - } - - // SVG is a special case. Though saving to svg is not supported, SVG->SVG is. - if !vips.SupportsSave(po.Format) && po.Format != imagetype.Unknown && po.Format != imagetype.SVG { - return ierrors.Wrap(newInvalidURLErrorf( - http.StatusUnprocessableEntity, - "Resulting image format is not supported: %s", po.Format, - ), 0, ierrors.WithCategory(categoryPathParsing)) - } - - imgRequestHeader := make(http.Header) - - // If ETag is enabled, we forward If-None-Match header - if config.ETagEnabled { - imgRequestHeader.Set(httpheaders.IfNoneMatch, r.Header.Get(httpheaders.IfNoneMatch)) - } - - // If LastModified is enabled, we forward If-Modified-Since header - if config.LastModifiedEnabled { - imgRequestHeader.Set(httpheaders.IfModifiedSince, r.Header.Get(httpheaders.IfModifiedSince)) - } - - if queueSem != nil { - acquired := queueSem.TryAcquire(1) - if !acquired { - panic(newTooManyRequestsError()) - } - defer queueSem.Release(1) - } - - // The heavy part starts here, so we need to restrict worker number - err = func() error { - defer monitoring.StartQueueSegment(ctx)() - - err = processingSem.Acquire(ctx, 1) - if err != nil { - // We don't actually need to check timeout here, - // but it's an easy way to check if this is an actual timeout - // or the request was canceled - if terr := server.CheckTimeout(ctx); terr != nil { - return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) - } - - // We should never reach this line as err could be only ctx.Err() - // and we've already checked for it. But beter safe than sorry - - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryQueue)) - } - - return nil - }() - if err != nil { - return err - } - defer processingSem.Release(1) - - stats.IncImagesInProgress() - defer stats.DecImagesInProgress() - - statusCode := http.StatusOK - - originData, originHeaders, err := func() (imagedata.ImageData, http.Header, error) { - downloadFinished := monitoring.StartDownloadingSegment(ctx, monitoringMeta.Filter( - monitoring.MetaSourceImageURL, - monitoring.MetaSourceImageOrigin, - )) - - downloadOpts := imagedata.DownloadOptions{ - Header: imgRequestHeader, - CookieJar: nil, - MaxSrcFileSize: po.SecurityOptions.MaxSrcFileSize, - DownloadFinished: downloadFinished, - } - - if config.CookiePassthrough { - downloadOpts.CookieJar, err = cookies.JarFromRequest(r) - if err != nil { - return nil, nil, ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload)) - } - } - - return imagedata.DownloadAsync(ctx, imageURL, "source image", downloadOpts) - }() - - // Close originData if no error occurred - if err == nil { - defer originData.Close() - } - - // Check that image detection didn't take too long - if terr := server.CheckTimeout(ctx); terr != nil { - return ierrors.Wrap(terr, 0, ierrors.WithCategory(categoryTimeout)) - } - - var nmErr imagefetcher.NotModifiedError - - // Respond with NotModified if image was not modified - if errors.As(err, &nmErr) { - hwr := hw.NewRequest(nmErr.Headers(), imageURL) - - respondWithNotModified(reqID, r, rw, po, imageURL, hwr) - return nil - } - - // If error is not related to NotModified, respond with fallback image - if err != nil { - ierr := ierrors.Wrap(err, 0, ierrors.WithCategory(categoryDownload)) - if config.ReportDownloadingErrors { - ierr = ierrors.Wrap(ierr, 0, ierrors.WithShouldReport(true)) - } - - if imagedata.FallbackImage == nil { - return ierr - } - - // Just send error - monitoring.SendError(ctx, categoryDownload, ierr) - - // We didn't return, so we have to report error - if ierr.ShouldReport() { - errorreport.Report(ierr, r) - } - - log.WithField("request_id", reqID).Warningf("Could not load image %s. Using fallback image. %s", imageURL, ierr.Error()) - - if config.FallbackImageHTTPCode > 0 { - statusCode = config.FallbackImageHTTPCode - } else { - statusCode = ierr.StatusCode() - } - - originData = imagedata.FallbackImage - originHeaders = imagedata.FallbackImageHeaders.Clone() - - if config.FallbackImageTTL > 0 { - originHeaders.Set("Fallback-Image", "1") - } - } - - if !vips.SupportsLoad(originData.Format()) { - return ierrors.Wrap(newInvalidURLErrorf( - http.StatusUnprocessableEntity, - "Source image format is not supported: %s", originData.Format(), - ), 0, ierrors.WithCategory(categoryProcessing)) - } - - result, err := func() (*processing.Result, error) { - defer monitoring.StartProcessingSegment(ctx, monitoringMeta.Filter(monitoring.MetaProcessingOptions))() - return processing.ProcessImage(ctx, originData, po) - }() - - // Let's close resulting image data only if it differs from the source image data - if result != nil && result.OutData != nil && result.OutData != originData { - defer result.OutData.Close() - } - - // First, check if the processing error wasn't caused by an image data error - if derr := originData.Error(); derr != nil { - return ierrors.Wrap(derr, 0, ierrors.WithCategory(categoryDownload)) - } - - // If it wasn't, than it was a processing error - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryProcessing)) - } - - hwr := hw.NewRequest(originHeaders, imageURL) - - // Write debug headers. It seems unlogical to move they to headerwriter since they're - // not used anywhere else. - err = writeDebugHeaders(rw, result, originData) - if err != nil { - return ierrors.Wrap(err, 0, ierrors.WithCategory(categoryImageDataSize)) - } - - err = respondWithImage(reqID, r, rw, statusCode, result.OutData, po, imageURL, hwr) - if err != nil { - return err - } - - return nil -} diff --git a/processing_handler_test.go b/processing_handler_test.go index 166c13a1..b2bd2f18 100644 --- a/processing_handler_test.go +++ b/processing_handler_test.go @@ -1,5 +1,10 @@ package main +// NOTE: this test is the integration test for the processing handler. We can't extract and +// move it to handlers package yet because it depends on the global routes, methods and +// initialization functions. Once those would we wrapped into structures, we'll be able to move this test +// to where it belongs. + import ( "fmt" "io" diff --git a/security/signature.go b/security/signature.go index 7733b24f..cc577100 100644 --- a/security/signature.go +++ b/security/signature.go @@ -4,6 +4,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/base64" + "slices" "github.com/imgproxy/imgproxy/v3/config" ) @@ -13,10 +14,8 @@ func VerifySignature(signature, path string) error { return nil } - for _, s := range config.TrustedSignatures { - if s == signature { - return nil - } + if slices.Contains(config.TrustedSignatures, signature) { + return nil } messageMAC, err := base64.RawURLEncoding.DecodeString(signature) @@ -36,6 +35,13 @@ func VerifySignature(signature, path string) error { func signatureFor(str string, key, salt []byte, signatureSize int) []byte { mac := hmac.New(sha256.New, key) mac.Write(salt) + + // It's supposed that path starts with '/'. However, if and input path comes with the + // leading slash split, let's re-add it here. + if str[0] != '/' { + mac.Write([]byte{'/'}) + } + mac.Write([]byte(str)) expectedMAC := mac.Sum(nil) if signatureSize < 32 { diff --git a/security/signature_test.go b/security/signature_test.go index a0ef0436..7594c38f 100644 --- a/security/signature_test.go +++ b/security/signature_test.go @@ -20,19 +20,19 @@ func (s *SignatureTestSuite) SetupTest() { } func (s *SignatureTestSuite) TestVerifySignature() { - err := VerifySignature("dtLwhdnPPiu_epMl1LrzheLpvHas-4mwvY6L3Z8WwlY", "asd") + err := VerifySignature("oWaL7QoW5TsgbuiS9-5-DI8S3Ibbo1gdB2SteJh3a20", "asd") s.Require().NoError(err) } func (s *SignatureTestSuite) TestVerifySignatureTruncated() { config.SignatureSize = 8 - err := VerifySignature("dtLwhdnPPis", "asd") + err := VerifySignature("oWaL7QoW5Ts", "asd") s.Require().NoError(err) } func (s *SignatureTestSuite) TestVerifySignatureInvalid() { - err := VerifySignature("dtLwhdnPPis", "asd") + err := VerifySignature("oWaL7QoW5Ts", "asd") s.Require().Error(err) } @@ -40,10 +40,10 @@ func (s *SignatureTestSuite) TestVerifySignatureMultiplePairs() { config.Keys = append(config.Keys, []byte("test-key2")) config.Salts = append(config.Salts, []byte("test-salt2")) - err := VerifySignature("dtLwhdnPPiu_epMl1LrzheLpvHas-4mwvY6L3Z8WwlY", "asd") + err := VerifySignature("jYz1UZ7j1BCdSzH3pZhaYf0iuz0vusoOTdqJsUT6WXI", "asd") s.Require().NoError(err) - err = VerifySignature("jbDffNPt1-XBgDccsaE-XJB9lx8JIJqdeYIZKgOqZpg", "asd") + err = VerifySignature("oWaL7QoW5TsgbuiS9-5-DI8S3Ibbo1gdB2SteJh3a20", "asd") s.Require().NoError(err) err = VerifySignature("dtLwhdnPPis", "asd") diff --git a/semaphores/config.go b/semaphores/config.go new file mode 100644 index 00000000..18a299bc --- /dev/null +++ b/semaphores/config.go @@ -0,0 +1,43 @@ +package semaphores + +import ( + "fmt" + "runtime" + + "github.com/imgproxy/imgproxy/v3/config" +) + +// Config represents handler config +type Config struct { + RequestsQueueSize int // Request queue size + Workers int // Number of workers +} + +// NewDefaultConfig creates a new configuration with defaults +func NewDefaultConfig() *Config { + return &Config{ + RequestsQueueSize: 0, + Workers: runtime.GOMAXPROCS(0) * 2, + } +} + +// LoadFromEnv loads config from environment variables +func LoadFromEnv(c *Config) (*Config, error) { + c.RequestsQueueSize = config.RequestsQueueSize + c.Workers = config.Workers + + return c, nil +} + +// Validate checks configuration values +func (c *Config) Validate() error { + if c.RequestsQueueSize < 0 { + return fmt.Errorf("requests queue size should be greater than or equal 0, now - %d", c.RequestsQueueSize) + } + + if c.Workers <= 0 { + return fmt.Errorf("workers number should be greater than 0, now - %d", c.Workers) + } + + return nil +} diff --git a/semaphores/errors.go b/semaphores/errors.go new file mode 100644 index 00000000..22352954 --- /dev/null +++ b/semaphores/errors.go @@ -0,0 +1,21 @@ +package semaphores + +import ( + "net/http" + + "github.com/imgproxy/imgproxy/v3/ierrors" +) + +type TooManyRequestsError struct{} + +func newTooManyRequestsError() error { + return ierrors.Wrap( + TooManyRequestsError{}, + 1, + ierrors.WithStatusCode(http.StatusTooManyRequests), + ierrors.WithPublicMessage("Too many requests"), + ierrors.WithShouldReport(false), + ) +} + +func (e TooManyRequestsError) Error() string { return "Too many requests" } diff --git a/semaphores/semaphores.go b/semaphores/semaphores.go new file mode 100644 index 00000000..4e506bcd --- /dev/null +++ b/semaphores/semaphores.go @@ -0,0 +1,65 @@ +package semaphores + +import ( + "context" + + "github.com/imgproxy/imgproxy/v3/monitoring" + "golang.org/x/sync/semaphore" +) + +// Semaphores is a container for the queue and processing semaphores +type Semaphores struct { + // queueSize semaphore: limits the queueSize size + queueSize *semaphore.Weighted + + // processing semaphore: limits the number of concurrent image processings + processing *semaphore.Weighted +} + +// New creates new semaphores instance +func New(config *Config) (*Semaphores, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + var queue *semaphore.Weighted + + if config.RequestsQueueSize > 0 { + queue = semaphore.NewWeighted(int64(config.RequestsQueueSize + config.Workers)) + } + + processing := semaphore.NewWeighted(int64(config.Workers)) + + return &Semaphores{ + queueSize: queue, + processing: processing, + }, nil +} + +// AcquireQueue acquires the queue semaphore and returns release function and error. +// if queue semaphore is not configured, it returns a noop anonymous function to make +// semaphore usage transparent. +func (s *Semaphores) AcquireQueue() (context.CancelFunc, error) { + if s.queueSize == nil { + return func() {}, nil // return no-op cancel function if semaphore is disabled + } + + acquired := s.queueSize.TryAcquire(1) + if !acquired { + return nil, newTooManyRequestsError() + } + + return func() { s.queueSize.Release(1) }, nil +} + +// AcquireProcessing acquires the processing semaphore +func (s *Semaphores) AcquireProcessing(ctx context.Context) (context.CancelFunc, error) { + defer monitoring.StartQueueSegment(ctx)() + + err := s.processing.Acquire(ctx, 1) + if err != nil { + return nil, err + } + + return func() { s.processing.Release(1) }, nil +} diff --git a/semaphores/semaphores_test.go b/semaphores/semaphores_test.go new file mode 100644 index 00000000..a5eaef42 --- /dev/null +++ b/semaphores/semaphores_test.go @@ -0,0 +1,50 @@ +package semaphores + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSemaphoresQueueDisabled(t *testing.T) { + s, err := New(&Config{RequestsQueueSize: 0, Workers: 1}) + require.NoError(t, err) + + // Queue acquire should always work when disabled + release, err := s.AcquireQueue() + require.NoError(t, err) + release() // Should not panic + + procRelease, err := s.AcquireProcessing(t.Context()) + require.NoError(t, err) + procRelease() +} + +func TestSemaphoresQueueEnabled(t *testing.T) { + s, err := New(&Config{RequestsQueueSize: 1, Workers: 1}) + require.NoError(t, err) + + // Should be able to acquire up to queue size + workers + release1, err := s.AcquireQueue() + require.NoError(t, err) + + release2, err := s.AcquireQueue() + require.NoError(t, err) + + // Third should fail (exceeds capacity) + _, err = s.AcquireQueue() + require.Error(t, err) + + // Release and try again + release1() + release3, err := s.AcquireQueue() + require.NoError(t, err) + + release2() + release3() +} + +func TestSemaphoresInvalidConfig(t *testing.T) { + _, err := New(&Config{RequestsQueueSize: 0, Workers: 0}) + require.Error(t, err) +} diff --git a/server/config.go b/server/config.go index a9b9aaa5..6af4c80b 100644 --- a/server/config.go +++ b/server/config.go @@ -51,7 +51,7 @@ func NewDefaultConfig() *Config { } // LoadFromEnv overrides current values with environment variables -func (c *Config) LoadFromEnv() (*Config, error) { +func LoadFromEnv(c *Config) (*Config, error) { c.Network = config.Network c.Bind = config.Bind c.PathPrefix = config.PathPrefix