diff --git a/errors.go b/errors.go index 9fd45702..7d660ccf 100644 --- a/errors.go +++ b/errors.go @@ -17,7 +17,7 @@ const ( categoryDownload = "download" categoryProcessing = "processing" categoryIO = "IO" - categoryStreaming = "streaming" + categoryConfig = "config(tmp)" // NOTE: THIS IS TEMPORARY ) type ( diff --git a/handlers/stream/config.go b/handlers/stream/config.go index be93cbed..510aeafe 100644 --- a/handlers/stream/config.go +++ b/handlers/stream/config.go @@ -17,10 +17,10 @@ type Config struct { PassthroughResponseHeaders []string } -// NewConfigFromEnv creates a new Config instance from environment variables -func NewConfigFromEnv() *Config { +// NewDefaultConfig returns a new Config instance with default values. +func NewDefaultConfig() *Config { return &Config{ - CookiePassthrough: config.CookiePassthrough, + CookiePassthrough: false, PassthroughRequestHeaders: []string{ httpheaders.IfNoneMatch, httpheaders.IfModifiedSince, @@ -37,3 +37,14 @@ func NewConfigFromEnv() *Config { }, } } + +// LoadFromEnv loads config variables from environment +func (c *Config) LoadFromEnv() (*Config, error) { + c.CookiePassthrough = config.CookiePassthrough + return c, nil +} + +// Validate checks config for errors +func (c *Config) Validate() error { + return nil +} diff --git a/handlers/stream/handler.go b/handlers/stream/handler.go index ef8fcb8e..b4640fb9 100644 --- a/handlers/stream/handler.go +++ b/handlers/stream/handler.go @@ -35,9 +35,9 @@ var ( // Handler handles image passthrough requests, allowing images to be streamed directly type Handler struct { - fetcher *imagefetcher.Fetcher // Fetcher instance to handle image fetching - config *Config // Configuration for the streamer - hwConfig *headerwriter.Config // Configuration for header writing + config *Config // Configuration for the streamer + fetcher *imagefetcher.Fetcher // Fetcher instance to handle image fetching + hw *headerwriter.Writer // Configured HeaderWriter instance } // request holds the parameters and state for a single streaming request @@ -51,12 +51,16 @@ type request struct { } // New creates new handler object -func New(config *Config, hwConfig *headerwriter.Config, fetcher *imagefetcher.Fetcher) *Handler { - return &Handler{ - fetcher: fetcher, - config: config, - hwConfig: hwConfig, +func New(config *Config, hw *headerwriter.Writer, fetcher *imagefetcher.Fetcher) (*Handler, error) { + if err := config.Validate(); err != nil { + return nil, err } + + return &Handler{ + fetcher: fetcher, + config: config, + hw: hw, + }, nil } // Stream handles the image passthrough request, streaming the image directly to the response writer @@ -110,7 +114,8 @@ func (s *request) execute(ctx context.Context) error { } // Output streaming response headers - hw := headerwriter.New(s.handler.hwConfig, res.Header, s.imageURL) + hw := s.handler.hw.NewRequest(res.Header, s.imageURL) + hw.Passthrough(s.handler.config.PassthroughResponseHeaders) // NOTE: priority? This is lowest as it was hw.SetContentLength(int(res.ContentLength)) hw.SetCanonical() diff --git a/handlers/stream/handler_test.go b/handlers/stream/handler_test.go index 1762df68..adaf3b94 100644 --- a/handlers/stream/handler_test.go +++ b/handlers/stream/handler_test.go @@ -52,10 +52,20 @@ func (s *HandlerTestSuite) SetupTest() { tr, err := transport.NewTransport() s.Require().NoError(err) - fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) + fc := imagefetcher.NewDefaultConfig() + + fetcher, err := imagefetcher.NewFetcher(tr, fc) s.Require().NoError(err) - s.handler = New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) + cfg := NewDefaultConfig() + + hwc := headerwriter.NewDefaultConfig() + hw, err := headerwriter.New(hwc) + s.Require().NoError(err) + + h, err := New(cfg, hw, fetcher) + s.Require().NoError(err) + s.handler = h } func (s *HandlerTestSuite) readTestFile(name string) []byte { @@ -207,8 +217,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { oneMinuteDelta = float64(time.Minute) ) - // Set this explicitly for testing purposes - config.TTL = 4242 + defaultTTL := 4242 testCases := []testCase{ { @@ -248,7 +257,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { timestampOffset: nil, expectedStatusCode: 200, validate: func(t *testing.T, res *http.Response) { - s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second) + s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second) }, }, // When expires is set in processing options, but not present in the response @@ -320,17 +329,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { timestampOffset: nil, expectedStatusCode: 200, validate: func(t *testing.T, res *http.Response) { - s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second) + s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second) }, }, } for _, tc := range testCases { s.Run(tc.name, func() { - // Set config values for this test - config.CacheControlPassthrough = tc.cacheControlPassthrough - config.TTL = 4242 // Set consistent TTL for testing - data := s.readTestFile("test1.png") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -345,10 +350,21 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() { tr, err := transport.NewTransport() s.Require().NoError(err) - fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) + fc := imagefetcher.NewDefaultConfig() + + fetcher, err := imagefetcher.NewFetcher(tr, fc) s.Require().NoError(err) - handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) + cfg := NewDefaultConfig() + hwc := headerwriter.NewDefaultConfig() + hwc.CacheControlPassthrough = tc.cacheControlPassthrough + hwc.DefaultTTL = 4242 + + hw, err := headerwriter.New(hwc) + s.Require().NoError(err) + + handler, err := New(cfg, hw, fetcher) + s.Require().NoError(err) req := httptest.NewRequest("GET", "/", nil) rw := httptest.NewRecorder() @@ -424,20 +440,23 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() { // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service. func (s *HandlerTestSuite) TestHandlerCookiePassthrough() { - // Enable cookie passthrough for this test - config.CookiePassthrough = true - defer func() { - config.CookiePassthrough = false // Reset after test - }() - // Create new handler with updated config tr, err := transport.NewTransport() s.Require().NoError(err) - fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) + fc := imagefetcher.NewDefaultConfig() + fetcher, err := imagefetcher.NewFetcher(tr, fc) s.Require().NoError(err) - handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) + cfg := NewDefaultConfig() + cfg.CookiePassthrough = true + + hwc := headerwriter.NewDefaultConfig() + hw, err := headerwriter.New(hwc) + s.Require().NoError(err) + + handler, err := New(cfg, hw, fetcher) + s.Require().NoError(err) data := s.readTestFile("test1.png") @@ -478,16 +497,24 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() { defer ts.Close() for _, sc := range []bool{true, false} { - config.SetCanonicalHeader = sc - // Create new handler with updated config tr, err := transport.NewTransport() s.Require().NoError(err) - fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) + fc := imagefetcher.NewDefaultConfig() + fetcher, err := imagefetcher.NewFetcher(tr, fc) s.Require().NoError(err) - handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) + cfg := NewDefaultConfig() + hwc := headerwriter.NewDefaultConfig() + + hwc.SetCanonicalHeader = sc + + hw, err := headerwriter.New(hwc) + s.Require().NoError(err) + + handler, err := New(cfg, hw, fetcher) + s.Require().NoError(err) req := httptest.NewRequest("GET", "/", nil) rw := httptest.NewRecorder() diff --git a/headerwriter/config.go b/headerwriter/config.go index 15ac9c0e..f2c6c7c7 100644 --- a/headerwriter/config.go +++ b/headerwriter/config.go @@ -1,6 +1,8 @@ package headerwriter import ( + "fmt" + "github.com/imgproxy/imgproxy/v3/config" ) @@ -15,20 +17,46 @@ type Config struct { SetVaryAccept bool // Whether to include Accept in Vary header } -// NewConfigFromEnv creates a new Config instance from the current configuration -func NewConfigFromEnv() *Config { +// NewDefaultConfig returns a new Config instance with default values. +func NewDefaultConfig() *Config { return &Config{ - SetCanonicalHeader: config.SetCanonicalHeader, - DefaultTTL: config.TTL, - FallbackImageTTL: config.FallbackImageTTL, - LastModifiedEnabled: config.LastModifiedEnabled, - CacheControlPassthrough: config.CacheControlPassthrough, - EnableClientHints: config.EnableClientHints, - SetVaryAccept: config.AutoWebp || - config.EnforceWebp || - config.AutoAvif || - config.EnforceAvif || - config.AutoJxl || - config.EnforceJxl, + SetCanonicalHeader: false, + DefaultTTL: 31536000, + FallbackImageTTL: 0, + LastModifiedEnabled: false, + CacheControlPassthrough: false, + EnableClientHints: false, + SetVaryAccept: false, } } + +// LoadFromEnv overrides configuration variables from environment +func (c *Config) LoadFromEnv() (*Config, error) { + c.SetCanonicalHeader = config.SetCanonicalHeader + c.DefaultTTL = config.TTL + c.FallbackImageTTL = config.FallbackImageTTL + c.LastModifiedEnabled = config.LastModifiedEnabled + c.CacheControlPassthrough = config.CacheControlPassthrough + c.EnableClientHints = config.EnableClientHints + c.SetVaryAccept = config.AutoWebp || + config.EnforceWebp || + config.AutoAvif || + config.EnforceAvif || + config.AutoJxl || + config.EnforceJxl + + return c, nil +} + +// Validate checks config for errors +func (c *Config) Validate() error { + if c.DefaultTTL < 0 { + return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL) + } + + if c.FallbackImageTTL < 0 { + return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL) + } + + return nil +} diff --git a/headerwriter/writer.go b/headerwriter/writer.go index fa7c9f82..1f4208a2 100644 --- a/headerwriter/writer.go +++ b/headerwriter/writer.go @@ -11,18 +11,27 @@ import ( "github.com/imgproxy/imgproxy/v3/httpheaders" ) -// Writer is a struct that builds HTTP response headers. +// Writer is a struct that creates header writer factories. type Writer struct { - config *Config + config *Config + varyValue string +} + +// writer is a private struct that builds HTTP response headers for a specific request. +type writer struct { + writer *Writer originalResponseHeaders http.Header // Original response headers result http.Header // Headers to be written to the response maxAge int // Current max age for Cache-Control header url string // URL of the request, used for canonical header - varyValue string // Vary header value } -// New creates a new HeaderBuilder instance with the provided origin headers and URL -func New(config *Config, originalResponseHeaders http.Header, url string) *Writer { +// New creates a new header writer factory with the provided config. +func New(config *Config) (*Writer, error) { + if err := config.Validate(); err != nil { + return nil, err + } + vary := make([]string, 0) if config.SetVaryAccept { @@ -36,31 +45,38 @@ func New(config *Config, originalResponseHeaders http.Header, url string) *Write varyValue := strings.Join(vary, ", ") return &Writer{ - config: config, + config: config, + varyValue: varyValue, + }, nil +} + +// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL. +func (w *Writer) NewRequest(originalResponseHeaders http.Header, url string) *writer { + return &writer{ + writer: w, originalResponseHeaders: originalResponseHeaders, url: url, result: make(http.Header), maxAge: -1, - varyValue: varyValue, } } // SetIsFallbackImage sets the Fallback-Image header to // indicate that the fallback image was used. -func (w *Writer) SetIsFallbackImage() { +func (w *writer) SetIsFallbackImage() { // We set maxAge to FallbackImageTTL if it's explicitly passed - if w.config.FallbackImageTTL < 0 { + if w.writer.config.FallbackImageTTL < 0 { return } // However, we should not overwrite existing value if set (or greater than ours) - if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL { - w.maxAge = w.config.FallbackImageTTL + if w.maxAge < 0 || w.maxAge > w.writer.config.FallbackImageTTL { + w.maxAge = w.writer.config.FallbackImageTTL } } // SetExpires sets the TTL from time -func (w *Writer) SetExpires(expires *time.Time) { +func (w *writer) SetExpires(expires *time.Time) { if expires == nil { return } @@ -70,13 +86,13 @@ func (w *Writer) SetExpires(expires *time.Time) { // If maxAge outlives expires or was not set, we'll use expires as maxAge. if w.maxAge < 0 || expires.Before(currentMaxAgeTime) { - w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds()))) + w.maxAge = min(w.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds()))) } } // SetLastModified sets the Last-Modified header from request -func (w *Writer) SetLastModified() { - if !w.config.LastModifiedEnabled { +func (w *writer) SetLastModified() { + if !w.writer.config.LastModifiedEnabled { return } @@ -89,14 +105,14 @@ func (w *Writer) SetLastModified() { } // SetVary sets the Vary header -func (w *Writer) SetVary() { - if len(w.varyValue) > 0 { - w.result.Set(httpheaders.Vary, w.varyValue) +func (w *writer) SetVary() { + if len(w.writer.varyValue) > 0 { + w.result.Set(httpheaders.Vary, w.writer.varyValue) } } // Passthrough copies specified headers from the original response headers to the response headers. -func (w *Writer) Passthrough(only []string) { +func (w *writer) Passthrough(only []string) { for _, key := range only { values := w.originalResponseHeaders.Values(key) @@ -108,7 +124,7 @@ func (w *Writer) Passthrough(only []string) { // CopyFrom copies specified headers from the headers object. Please note that // all the past operations may overwrite those values. -func (w *Writer) CopyFrom(headers http.Header, only []string) { +func (w *writer) CopyFrom(headers http.Header, only []string) { for _, key := range only { values := headers.Values(key) @@ -119,7 +135,7 @@ func (w *Writer) CopyFrom(headers http.Header, only []string) { } // SetContentLength sets the Content-Length header -func (w *Writer) SetContentLength(contentLength int) { +func (w *writer) SetContentLength(contentLength int) { if contentLength < 0 { return } @@ -128,25 +144,25 @@ func (w *Writer) SetContentLength(contentLength int) { } // SetContentType sets the Content-Type header -func (w *Writer) SetContentType(mime string) { +func (w *writer) SetContentType(mime string) { w.result.Set(httpheaders.ContentType, mime) } // writeCanonical sets the Link header with the canonical URL. // It is mandatory for any response if enabled in the configuration. -func (b *Writer) SetCanonical() { - if !b.config.SetCanonicalHeader { +func (w *writer) SetCanonical() { + if !w.writer.config.SetCanonicalHeader { return } - if strings.HasPrefix(b.url, "https://") || strings.HasPrefix(b.url, "http://") { - value := fmt.Sprintf(`<%s>; rel="canonical"`, b.url) - b.result.Set(httpheaders.Link, value) + if strings.HasPrefix(w.url, "https://") || strings.HasPrefix(w.url, "http://") { + value := fmt.Sprintf(`<%s>; rel="canonical"`, w.url) + w.result.Set(httpheaders.Link, value) } } // setCacheControl sets the Cache-Control header with the specified value. -func (w *Writer) setCacheControl(value int) bool { +func (w *writer) setCacheControl(value int) bool { if value <= 0 { return false } @@ -156,14 +172,14 @@ func (w *Writer) setCacheControl(value int) bool { } // setCacheControlNoCache sets the Cache-Control header to no-cache (default). -func (w *Writer) setCacheControlNoCache() { +func (w *writer) setCacheControlNoCache() { w.result.Set(httpheaders.CacheControl, "no-cache") } // setCacheControlPassthrough sets the Cache-Control header from the request // if passthrough is enabled in the configuration. -func (w *Writer) setCacheControlPassthrough() bool { - if !w.config.CacheControlPassthrough || w.maxAge > 0 { +func (w *writer) setCacheControlPassthrough() bool { + if !w.writer.config.CacheControlPassthrough || w.maxAge > 0 { return false } @@ -183,18 +199,18 @@ func (w *Writer) setCacheControlPassthrough() bool { } // setCSP sets the Content-Security-Policy header to prevent script execution. -func (w *Writer) setCSP() { +func (w *writer) setCSP() { w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'") } // Write writes the headers to the response writer. It does not overwrite // target headers, which were set outside the header writer. -func (w *Writer) Write(rw http.ResponseWriter) { +func (w *writer) Write(rw http.ResponseWriter) { // Then, let's try to set Cache-Control using priority order switch { case w.setCacheControl(w.maxAge): // First, try set explicit case w.setCacheControlPassthrough(): // Try to pick up from request headers - case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value + case w.setCacheControl(w.writer.config.DefaultTTL): // Fallback to default value default: w.setCacheControlNoCache() // By default we use no-cache } diff --git a/headerwriter/writer_test.go b/headerwriter/writer_test.go index 0c728cb6..ff5ea585 100644 --- a/headerwriter/writer_test.go +++ b/headerwriter/writer_test.go @@ -23,7 +23,7 @@ type writerTestCase struct { req http.Header res http.Header config Config - fn func(*Writer) + fn func(*writer) } func (s *HeaderWriterSuite) TestHeaderCases() { @@ -105,7 +105,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { SetCanonicalHeader: true, DefaultTTL: 3600, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetCanonical() }, }, @@ -134,7 +134,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { SetCanonicalHeader: false, DefaultTTL: 3600, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetCanonical() }, }, @@ -152,7 +152,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { LastModifiedEnabled: true, DefaultTTL: 3600, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetLastModified() }, }, @@ -167,7 +167,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { DefaultTTL: 3600, FallbackImageTTL: 1, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetIsFallbackImage() }, }, @@ -181,7 +181,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ DefaultTTL: math.MaxInt32, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetExpires(&expires) }, }, @@ -196,7 +196,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { DefaultTTL: math.MaxInt32, FallbackImageTTL: 600, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetIsFallbackImage() w.SetExpires(&shortExpires) }, @@ -213,7 +213,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { EnableClientHints: true, SetVaryAccept: true, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetVary() }, }, @@ -228,7 +228,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{}, - fn: func(w *Writer) { + fn: func(w *writer) { w.Passthrough([]string{"X-Test"}) }, }, @@ -241,7 +241,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{}, - fn: func(w *Writer) { + fn: func(w *writer) { h := http.Header{} h.Set("X-From", "baz") w.CopyFrom(h, []string{"X-From"}) @@ -256,7 +256,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{}, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetContentLength(123) }, }, @@ -269,7 +269,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, }, config: Config{}, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetContentType("image/png") }, }, @@ -283,7 +283,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ DefaultTTL: 3600, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetExpires(nil) }, }, @@ -298,7 +298,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ SetVaryAccept: true, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetVary() }, }, @@ -313,7 +313,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() { config: Config{ EnableClientHints: true, }, - fn: func(w *Writer) { + fn: func(w *writer) { w.SetVary() }, }, @@ -321,7 +321,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() { for _, tc := range tt { s.Run(tc.name, func() { - writer := New(&tc.config, tc.req, tc.url) + factory, err := New(&tc.config) + s.Require().NoError(err) + + writer := factory.NewRequest(tc.req, tc.url) if tc.fn != nil { tc.fn(writer) diff --git a/imagedata/download.go b/imagedata/download.go index 1cf44c12..f61c8d51 100644 --- a/imagedata/download.go +++ b/imagedata/download.go @@ -40,7 +40,12 @@ func initDownloading() error { return err } - Fetcher, err = imagefetcher.NewFetcher(ts, imagefetcher.NewConfigFromEnv()) + c, err := imagefetcher.NewDefaultConfig().LoadFromEnv() + if err != nil { + return ierrors.Wrap(err, 0, ierrors.WithPrefix("configuration error")) + } + + Fetcher, err = imagefetcher.NewFetcher(ts, c) if err != nil { return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create image fetcher")) } diff --git a/imagefetcher/config.go b/imagefetcher/config.go index 3477df6a..cf82d957 100644 --- a/imagefetcher/config.go +++ b/imagefetcher/config.go @@ -8,9 +8,20 @@ type Config struct { MaxRedirects int } -// NewConfigFromEnv creates a new Config instance from environment variables or defaults. -func NewConfigFromEnv() *Config { +// NewDefaultConfig returns a new Config instance with default values. +func NewDefaultConfig() *Config { return &Config{ - MaxRedirects: config.MaxRedirects, + MaxRedirects: 10, } } + +// LoadFromEnv loads config variables from env +func (c *Config) LoadFromEnv() (*Config, error) { + c.MaxRedirects = config.MaxRedirects + return c, nil +} + +// Validate checks config for errors +func (c *Config) Validate() error { + return nil +} diff --git a/imagefetcher/fetcher.go b/imagefetcher/fetcher.go index 3c80168e..618c6b4e 100644 --- a/imagefetcher/fetcher.go +++ b/imagefetcher/fetcher.go @@ -26,6 +26,10 @@ type Fetcher struct { // NewFetcher creates a new ImageFetcher with the provided transport func NewFetcher(transport *transport.Transport, config *Config) (*Fetcher, error) { + if err := config.Validate(); err != nil { + return nil, err + } + return &Fetcher{transport, config}, nil } diff --git a/main.go b/main.go index 0208f950..2b2352f4 100644 --- a/main.go +++ b/main.go @@ -137,8 +137,16 @@ func run(ctx context.Context) error { return err } - cfg := server.NewConfigFromEnv() - r := server.NewRouter(cfg) + cfg, err := server.NewDefaultConfig().LoadFromEnv() + if err != nil { + return err + } + + r, err := server.NewRouter(cfg) + if err != nil { + return err + } + s, err := server.Start(cancel, buildRouter(r)) if err != nil { return err diff --git a/processing_handler.go b/processing_handler.go index 86eac324..1da0e8ca 100644 --- a/processing_handler.go +++ b/processing_handler.go @@ -277,10 +277,29 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) err } if po.Raw { + // NOTE: This is temporary, there would be no categoryConfig once we + // finish with refactoring. // TODO: Move this up - cfg := stream.NewConfigFromEnv() - hwCfg := headerwriter.NewConfigFromEnv() - handler := stream.New(cfg, hwCfg, imagedata.Fetcher) + cfg, cerr := stream.NewDefaultConfig().LoadFromEnv() + if cerr != nil { + return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig)) + } + + hwc, cerr := headerwriter.NewDefaultConfig().LoadFromEnv() + if cerr != nil { + return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig)) + } + + hw, cerr := headerwriter.New(hwc) + if cerr != nil { + return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig)) + } + + handler, cerr := stream.New(cfg, hw, imagedata.Fetcher) + if cerr != nil { + return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig)) + } + return handler.Execute(ctx, r, imageURL, reqID, po, rw) } diff --git a/processing_handler_test.go b/processing_handler_test.go index 2b07270a..db9d0fd1 100644 --- a/processing_handler_test.go +++ b/processing_handler_test.go @@ -48,7 +48,11 @@ func (s *ProcessingHandlerTestSuite) SetupSuite() { logrus.SetOutput(io.Discard) - s.router = buildRouter(server.NewRouter(server.NewConfigFromEnv())) + cfg := server.NewDefaultConfig() + r, err := server.NewRouter(cfg) + s.Require().NoError(err) + + s.router = buildRouter(r) } func (s *ProcessingHandlerTestSuite) TeardownSuite() { diff --git a/server/config.go b/server/config.go index 8f9cef15..a9b9aaa5 100644 --- a/server/config.go +++ b/server/config.go @@ -1,6 +1,8 @@ package server import ( + "errors" + "fmt" "time" "github.com/imgproxy/imgproxy/v3/config" @@ -19,6 +21,7 @@ type Config struct { PathPrefix string // Path prefix for the server MaxClients int // Maximum number of concurrent clients ReadRequestTimeout time.Duration // Timeout for reading requests + WriteResponseTimeout time.Duration // Timeout for writing responses KeepAliveTimeout time.Duration // Timeout for keep-alive connections GracefulTimeout time.Duration // Timeout for graceful shutdown CORSAllowOrigin string // CORS allowed origin @@ -26,24 +29,70 @@ type Config struct { DevelopmentErrorsMode bool // Enable development mode for detailed error messages SocketReusePort bool // Enable SO_REUSEPORT socket option HealthCheckPath string // Health check path from config - WriteResponseTimeout time.Duration } -// NewConfigFromEnv creates a new Config instance from environment variables -func NewConfigFromEnv() *Config { +// NewDefaultConfig returns default config values +func NewDefaultConfig() *Config { return &Config{ - Network: config.Network, - Bind: config.Bind, - PathPrefix: config.PathPrefix, - MaxClients: config.MaxClients, - ReadRequestTimeout: time.Duration(config.ReadRequestTimeout) * time.Second, - KeepAliveTimeout: time.Duration(config.KeepAliveTimeout) * time.Second, + Network: "tcp", + Bind: ":8080", + PathPrefix: "", + MaxClients: 2048, + ReadRequestTimeout: 10 * time.Second, + KeepAliveTimeout: 10 * time.Second, + WriteResponseTimeout: 10 * time.Second, GracefulTimeout: gracefulTimeout, - CORSAllowOrigin: config.AllowOrigin, - Secret: config.Secret, - DevelopmentErrorsMode: config.DevelopmentErrorsMode, - SocketReusePort: config.SoReuseport, - HealthCheckPath: config.HealthCheckPath, - WriteResponseTimeout: time.Duration(config.WriteResponseTimeout) * time.Second, + CORSAllowOrigin: "", + Secret: "", + DevelopmentErrorsMode: false, + SocketReusePort: false, + HealthCheckPath: "", } } + +// LoadFromEnv overrides current values with environment variables +func (c *Config) LoadFromEnv() (*Config, error) { + c.Network = config.Network + c.Bind = config.Bind + c.PathPrefix = config.PathPrefix + c.MaxClients = config.MaxClients + c.ReadRequestTimeout = time.Duration(config.ReadRequestTimeout) * time.Second + c.KeepAliveTimeout = time.Duration(config.KeepAliveTimeout) * time.Second + c.GracefulTimeout = gracefulTimeout + c.CORSAllowOrigin = config.AllowOrigin + c.Secret = config.Secret + c.DevelopmentErrorsMode = config.DevelopmentErrorsMode + c.SocketReusePort = config.SoReuseport + c.HealthCheckPath = config.HealthCheckPath + + return c, nil +} + +// Validate checks that the config values are valid +func (c *Config) Validate() error { + if len(c.Bind) == 0 { + return errors.New("bind address is not defined") + } + + if c.MaxClients < 0 { + return fmt.Errorf("max clients number should be greater than or equal 0, now - %d", c.MaxClients) + } + + if c.ReadRequestTimeout <= 0 { + return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout) + } + + if c.WriteResponseTimeout <= 0 { + return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout) + } + + if c.KeepAliveTimeout < 0 { + return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout) + } + + if c.GracefulTimeout < 0 { + return fmt.Errorf("graceful timeout should be greater than or equal to 0, now - %d", c.GracefulTimeout) + } + + return nil +} diff --git a/server/router.go b/server/router.go index c8e44d60..3790f2ad 100644 --- a/server/router.go +++ b/server/router.go @@ -5,6 +5,7 @@ import ( "net" "net/http" "regexp" + "slices" "strings" nanoid "github.com/matoous/go-nanoid/v2" @@ -47,8 +48,12 @@ type Router struct { } // NewRouter creates a new Router instance -func NewRouter(config *Config) *Router { - return &Router{config: config} +func NewRouter(config *Config) (*Router, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + return &Router{config: config}, nil } // add adds an abitary route to the router @@ -57,14 +62,29 @@ func (r *Router) add(method, prefix string, exact bool, handler RouteHandler, mi handler = m(handler) } - route := &route{method: method, path: r.config.PathPrefix + prefix, handler: handler, exact: exact} + newRoute := &route{ + method: method, + path: r.config.PathPrefix + prefix, + handler: handler, + exact: exact, + } - r.routes = append( - r.routes, - route, - ) + r.routes = append(r.routes, newRoute) - return route + // Sort routes by exact flag, exact routes go first in the + // same order they were added + slices.SortStableFunc(r.routes, func(a, b *route) int { + switch { + case a.exact == b.exact: + return 0 + case a.exact: + return -1 + default: + return 1 + } + }) + + return newRoute } // GET adds GET route diff --git a/server/router_test.go b/server/router_test.go index 4e704a1c..dd39187b 100644 --- a/server/router_test.go +++ b/server/router_test.go @@ -16,13 +16,13 @@ type RouterTestSuite struct { } func (s *RouterTestSuite) SetupTest() { - c := NewConfigFromEnv() - c.PathPrefix = "/api" - s.router = NewRouter(c) -} + c := NewDefaultConfig() -func TestRouterSuite(t *testing.T) { - suite.Run(t, new(RouterTestSuite)) + c.PathPrefix = "/api" + r, err := NewRouter(c) + s.Require().NoError(err) + + s.router = r } // TestHTTPMethods tests route methods registration and HTTP requests @@ -294,3 +294,23 @@ func (s *RouterTestSuite) TestReplaceIP() { }) } } + +// TestRouteOrder checks exact/non-exact insertion order +func (s *RouterTestSuite) TestRouteOrder() { + + h := func(reqID string, rw http.ResponseWriter, req *http.Request) error { + return nil + } + + s.router.GET("/test", false, h) + s.router.GET("/test/path", true, h) + s.router.GET("/test/path/nested", true, h) + + s.Require().Equal("/api/test/path", s.router.routes[0].path) + s.Require().Equal("/api/test/path/nested", s.router.routes[1].path) + s.Require().Equal("/api/test", s.router.routes[2].path) +} + +func TestRouterSuite(t *testing.T) { + suite.Run(t, new(RouterTestSuite)) +} diff --git a/server/server_test.go b/server/server_test.go index 0b4b2989..4cae08e7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/stretchr/testify/suite" ) @@ -21,10 +20,13 @@ type ServerTestSuite struct { } func (s *ServerTestSuite) SetupTest() { - config.Reset() - s.config = NewConfigFromEnv() + c := NewDefaultConfig() + + s.config = c s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment - s.blankRouter = NewRouter(s.config) + r, err := NewRouter(s.config) + s.Require().NoError(err) + s.blankRouter = r } func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error { @@ -41,12 +43,11 @@ func (s *ServerTestSuite) TestStartServerWithInvalidBind() { cancelCalled.Store(true) } - invalidConfig := &Config{ - Network: "tcp", - Bind: "invalid-address", // Invalid address - } + invalidConfig := NewDefaultConfig() + invalidConfig.Bind = "-1.-1.-1.-1" // Invalid address - r := NewRouter(invalidConfig) + r, err := NewRouter(invalidConfig) + s.Require().NoError(err) server, err := Start(cancelWrapper, r) @@ -109,10 +110,11 @@ func (s *ServerTestSuite) TestWithCORS() { for _, tt := range tests { s.Run(tt.name, func() { - config := &Config{ - CORSAllowOrigin: tt.corsAllowOrigin, - } - router := NewRouter(config) + config := NewDefaultConfig() + config.CORSAllowOrigin = tt.corsAllowOrigin + + router, err := NewRouter(config) + s.Require().NoError(err) wrappedHandler := router.WithCORS(s.mockHandler) @@ -154,10 +156,11 @@ func (s *ServerTestSuite) TestWithSecret() { for _, tt := range tests { s.Run(tt.name, func() { - config := &Config{ - Secret: tt.secret, - } - router := NewRouter(config) + config := NewDefaultConfig() + config.Secret = tt.secret + + router, err := NewRouter(config) + s.Require().NoError(err) wrappedHandler := router.WithSecret(s.mockHandler) @@ -167,7 +170,7 @@ func (s *ServerTestSuite) TestWithSecret() { } rw := httptest.NewRecorder() - err := wrappedHandler("test-req-id", rw, req) + err = wrappedHandler("test-req-id", rw, req) if tt.expectError { s.Require().Error(err)