IMG-54: NewDefaultConfig(), routes ordering exact/non-exact (#1504)

* NewDefaultConfig() + LoadFromEnv()

* Route order

* Changed route switch

* categoryConfig

* Use Default() in tests
This commit is contained in:
Victor Sokolov
2025-08-25 19:52:29 +02:00
committed by GitHub
parent ec566ce1c0
commit 6c9d26e8f5
17 changed files with 391 additions and 158 deletions

View File

@@ -17,7 +17,7 @@ const (
categoryDownload = "download" categoryDownload = "download"
categoryProcessing = "processing" categoryProcessing = "processing"
categoryIO = "IO" categoryIO = "IO"
categoryStreaming = "streaming" categoryConfig = "config(tmp)" // NOTE: THIS IS TEMPORARY
) )
type ( type (

View File

@@ -17,10 +17,10 @@ type Config struct {
PassthroughResponseHeaders []string PassthroughResponseHeaders []string
} }
// NewConfigFromEnv creates a new Config instance from environment variables // NewDefaultConfig returns a new Config instance with default values.
func NewConfigFromEnv() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
CookiePassthrough: config.CookiePassthrough, CookiePassthrough: false,
PassthroughRequestHeaders: []string{ PassthroughRequestHeaders: []string{
httpheaders.IfNoneMatch, httpheaders.IfNoneMatch,
httpheaders.IfModifiedSince, httpheaders.IfModifiedSince,
@@ -37,3 +37,14 @@ func NewConfigFromEnv() *Config {
}, },
} }
} }
// LoadFromEnv loads config variables from environment
func (c *Config) LoadFromEnv() (*Config, error) {
c.CookiePassthrough = config.CookiePassthrough
return c, nil
}
// Validate checks config for errors
func (c *Config) Validate() error {
return nil
}

View File

@@ -35,9 +35,9 @@ var (
// Handler handles image passthrough requests, allowing images to be streamed directly // Handler handles image passthrough requests, allowing images to be streamed directly
type Handler struct { type Handler struct {
fetcher *imagefetcher.Fetcher // Fetcher instance to handle image fetching config *Config // Configuration for the streamer
config *Config // Configuration for the streamer fetcher *imagefetcher.Fetcher // Fetcher instance to handle image fetching
hwConfig *headerwriter.Config // Configuration for header writing hw *headerwriter.Writer // Configured HeaderWriter instance
} }
// request holds the parameters and state for a single streaming request // request holds the parameters and state for a single streaming request
@@ -51,12 +51,16 @@ type request struct {
} }
// New creates new handler object // New creates new handler object
func New(config *Config, hwConfig *headerwriter.Config, fetcher *imagefetcher.Fetcher) *Handler { func New(config *Config, hw *headerwriter.Writer, fetcher *imagefetcher.Fetcher) (*Handler, error) {
return &Handler{ if err := config.Validate(); err != nil {
fetcher: fetcher, return nil, err
config: config,
hwConfig: hwConfig,
} }
return &Handler{
fetcher: fetcher,
config: config,
hw: hw,
}, nil
} }
// Stream handles the image passthrough request, streaming the image directly to the response writer // Stream handles the image passthrough request, streaming the image directly to the response writer
@@ -110,7 +114,8 @@ func (s *request) execute(ctx context.Context) error {
} }
// Output streaming response headers // Output streaming response headers
hw := headerwriter.New(s.handler.hwConfig, res.Header, s.imageURL) hw := s.handler.hw.NewRequest(res.Header, s.imageURL)
hw.Passthrough(s.handler.config.PassthroughResponseHeaders) // NOTE: priority? This is lowest as it was hw.Passthrough(s.handler.config.PassthroughResponseHeaders) // NOTE: priority? This is lowest as it was
hw.SetContentLength(int(res.ContentLength)) hw.SetContentLength(int(res.ContentLength))
hw.SetCanonical() hw.SetCanonical()

View File

@@ -52,10 +52,20 @@ func (s *HandlerTestSuite) SetupTest() {
tr, err := transport.NewTransport() tr, err := transport.NewTransport()
s.Require().NoError(err) s.Require().NoError(err)
fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) fc := imagefetcher.NewDefaultConfig()
fetcher, err := imagefetcher.NewFetcher(tr, fc)
s.Require().NoError(err) s.Require().NoError(err)
s.handler = New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) cfg := NewDefaultConfig()
hwc := headerwriter.NewDefaultConfig()
hw, err := headerwriter.New(hwc)
s.Require().NoError(err)
h, err := New(cfg, hw, fetcher)
s.Require().NoError(err)
s.handler = h
} }
func (s *HandlerTestSuite) readTestFile(name string) []byte { func (s *HandlerTestSuite) readTestFile(name string) []byte {
@@ -207,8 +217,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
oneMinuteDelta = float64(time.Minute) oneMinuteDelta = float64(time.Minute)
) )
// Set this explicitly for testing purposes defaultTTL := 4242
config.TTL = 4242
testCases := []testCase{ testCases := []testCase{
{ {
@@ -248,7 +257,7 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
validate: func(t *testing.T, res *http.Response) { validate: func(t *testing.T, res *http.Response) {
s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second) s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
}, },
}, },
// When expires is set in processing options, but not present in the response // When expires is set in processing options, but not present in the response
@@ -320,17 +329,13 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
timestampOffset: nil, timestampOffset: nil,
expectedStatusCode: 200, expectedStatusCode: 200,
validate: func(t *testing.T, res *http.Response) { validate: func(t *testing.T, res *http.Response) {
s.Require().Equal(s.maxAgeValue(res), time.Duration(config.TTL)*time.Second) s.Require().Equal(s.maxAgeValue(res), time.Duration(defaultTTL)*time.Second)
}, },
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
// Set config values for this test
config.CacheControlPassthrough = tc.cacheControlPassthrough
config.TTL = 4242 // Set consistent TTL for testing
data := s.readTestFile("test1.png") data := s.readTestFile("test1.png")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -345,10 +350,21 @@ func (s *HandlerTestSuite) TestHandlerCacheControl() {
tr, err := transport.NewTransport() tr, err := transport.NewTransport()
s.Require().NoError(err) s.Require().NoError(err)
fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) fc := imagefetcher.NewDefaultConfig()
fetcher, err := imagefetcher.NewFetcher(tr, fc)
s.Require().NoError(err) s.Require().NoError(err)
handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) cfg := NewDefaultConfig()
hwc := headerwriter.NewDefaultConfig()
hwc.CacheControlPassthrough = tc.cacheControlPassthrough
hwc.DefaultTTL = 4242
hw, err := headerwriter.New(hwc)
s.Require().NoError(err)
handler, err := New(cfg, hw, fetcher)
s.Require().NoError(err)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@@ -424,20 +440,23 @@ func (s *HandlerTestSuite) TestHandlerErrorResponse() {
// TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service. // TestHandlerCookiePassthrough tests the cookie passthrough behavior of the streaming service.
func (s *HandlerTestSuite) TestHandlerCookiePassthrough() { func (s *HandlerTestSuite) TestHandlerCookiePassthrough() {
// Enable cookie passthrough for this test
config.CookiePassthrough = true
defer func() {
config.CookiePassthrough = false // Reset after test
}()
// Create new handler with updated config // Create new handler with updated config
tr, err := transport.NewTransport() tr, err := transport.NewTransport()
s.Require().NoError(err) s.Require().NoError(err)
fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) fc := imagefetcher.NewDefaultConfig()
fetcher, err := imagefetcher.NewFetcher(tr, fc)
s.Require().NoError(err) s.Require().NoError(err)
handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) cfg := NewDefaultConfig()
cfg.CookiePassthrough = true
hwc := headerwriter.NewDefaultConfig()
hw, err := headerwriter.New(hwc)
s.Require().NoError(err)
handler, err := New(cfg, hw, fetcher)
s.Require().NoError(err)
data := s.readTestFile("test1.png") data := s.readTestFile("test1.png")
@@ -478,16 +497,24 @@ func (s *HandlerTestSuite) TestHandlerCanonicalHeader() {
defer ts.Close() defer ts.Close()
for _, sc := range []bool{true, false} { for _, sc := range []bool{true, false} {
config.SetCanonicalHeader = sc
// Create new handler with updated config // Create new handler with updated config
tr, err := transport.NewTransport() tr, err := transport.NewTransport()
s.Require().NoError(err) s.Require().NoError(err)
fetcher, err := imagefetcher.NewFetcher(tr, imagefetcher.NewConfigFromEnv()) fc := imagefetcher.NewDefaultConfig()
fetcher, err := imagefetcher.NewFetcher(tr, fc)
s.Require().NoError(err) s.Require().NoError(err)
handler := New(NewConfigFromEnv(), headerwriter.NewConfigFromEnv(), fetcher) cfg := NewDefaultConfig()
hwc := headerwriter.NewDefaultConfig()
hwc.SetCanonicalHeader = sc
hw, err := headerwriter.New(hwc)
s.Require().NoError(err)
handler, err := New(cfg, hw, fetcher)
s.Require().NoError(err)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()

View File

@@ -1,6 +1,8 @@
package headerwriter package headerwriter
import ( import (
"fmt"
"github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/config"
) )
@@ -15,20 +17,46 @@ type Config struct {
SetVaryAccept bool // Whether to include Accept in Vary header SetVaryAccept bool // Whether to include Accept in Vary header
} }
// NewConfigFromEnv creates a new Config instance from the current configuration // NewDefaultConfig returns a new Config instance with default values.
func NewConfigFromEnv() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
SetCanonicalHeader: config.SetCanonicalHeader, SetCanonicalHeader: false,
DefaultTTL: config.TTL, DefaultTTL: 31536000,
FallbackImageTTL: config.FallbackImageTTL, FallbackImageTTL: 0,
LastModifiedEnabled: config.LastModifiedEnabled, LastModifiedEnabled: false,
CacheControlPassthrough: config.CacheControlPassthrough, CacheControlPassthrough: false,
EnableClientHints: config.EnableClientHints, EnableClientHints: false,
SetVaryAccept: config.AutoWebp || SetVaryAccept: false,
config.EnforceWebp ||
config.AutoAvif ||
config.EnforceAvif ||
config.AutoJxl ||
config.EnforceJxl,
} }
} }
// LoadFromEnv overrides configuration variables from environment
func (c *Config) LoadFromEnv() (*Config, error) {
c.SetCanonicalHeader = config.SetCanonicalHeader
c.DefaultTTL = config.TTL
c.FallbackImageTTL = config.FallbackImageTTL
c.LastModifiedEnabled = config.LastModifiedEnabled
c.CacheControlPassthrough = config.CacheControlPassthrough
c.EnableClientHints = config.EnableClientHints
c.SetVaryAccept = config.AutoWebp ||
config.EnforceWebp ||
config.AutoAvif ||
config.EnforceAvif ||
config.AutoJxl ||
config.EnforceJxl
return c, nil
}
// Validate checks config for errors
func (c *Config) Validate() error {
if c.DefaultTTL < 0 {
return fmt.Errorf("image TTL should be greater than or equal to 0, now - %d", c.DefaultTTL)
}
if c.FallbackImageTTL < 0 {
return fmt.Errorf("fallback image TTL should be greater than or equal to 0, now - %d", c.FallbackImageTTL)
}
return nil
}

View File

@@ -11,18 +11,27 @@ import (
"github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/httpheaders"
) )
// Writer is a struct that builds HTTP response headers. // Writer is a struct that creates header writer factories.
type Writer struct { type Writer struct {
config *Config config *Config
varyValue string
}
// writer is a private struct that builds HTTP response headers for a specific request.
type writer struct {
writer *Writer
originalResponseHeaders http.Header // Original response headers originalResponseHeaders http.Header // Original response headers
result http.Header // Headers to be written to the response result http.Header // Headers to be written to the response
maxAge int // Current max age for Cache-Control header maxAge int // Current max age for Cache-Control header
url string // URL of the request, used for canonical header url string // URL of the request, used for canonical header
varyValue string // Vary header value
} }
// New creates a new HeaderBuilder instance with the provided origin headers and URL // New creates a new header writer factory with the provided config.
func New(config *Config, originalResponseHeaders http.Header, url string) *Writer { func New(config *Config) (*Writer, error) {
if err := config.Validate(); err != nil {
return nil, err
}
vary := make([]string, 0) vary := make([]string, 0)
if config.SetVaryAccept { if config.SetVaryAccept {
@@ -36,31 +45,38 @@ func New(config *Config, originalResponseHeaders http.Header, url string) *Write
varyValue := strings.Join(vary, ", ") varyValue := strings.Join(vary, ", ")
return &Writer{ return &Writer{
config: config, config: config,
varyValue: varyValue,
}, nil
}
// NewRequest creates a new header writer instance for a specific request with the provided origin headers and URL.
func (w *Writer) NewRequest(originalResponseHeaders http.Header, url string) *writer {
return &writer{
writer: w,
originalResponseHeaders: originalResponseHeaders, originalResponseHeaders: originalResponseHeaders,
url: url, url: url,
result: make(http.Header), result: make(http.Header),
maxAge: -1, maxAge: -1,
varyValue: varyValue,
} }
} }
// SetIsFallbackImage sets the Fallback-Image header to // SetIsFallbackImage sets the Fallback-Image header to
// indicate that the fallback image was used. // indicate that the fallback image was used.
func (w *Writer) SetIsFallbackImage() { func (w *writer) SetIsFallbackImage() {
// We set maxAge to FallbackImageTTL if it's explicitly passed // We set maxAge to FallbackImageTTL if it's explicitly passed
if w.config.FallbackImageTTL < 0 { if w.writer.config.FallbackImageTTL < 0 {
return return
} }
// However, we should not overwrite existing value if set (or greater than ours) // However, we should not overwrite existing value if set (or greater than ours)
if w.maxAge < 0 || w.maxAge > w.config.FallbackImageTTL { if w.maxAge < 0 || w.maxAge > w.writer.config.FallbackImageTTL {
w.maxAge = w.config.FallbackImageTTL w.maxAge = w.writer.config.FallbackImageTTL
} }
} }
// SetExpires sets the TTL from time // SetExpires sets the TTL from time
func (w *Writer) SetExpires(expires *time.Time) { func (w *writer) SetExpires(expires *time.Time) {
if expires == nil { if expires == nil {
return return
} }
@@ -70,13 +86,13 @@ func (w *Writer) SetExpires(expires *time.Time) {
// If maxAge outlives expires or was not set, we'll use expires as maxAge. // If maxAge outlives expires or was not set, we'll use expires as maxAge.
if w.maxAge < 0 || expires.Before(currentMaxAgeTime) { if w.maxAge < 0 || expires.Before(currentMaxAgeTime) {
w.maxAge = min(w.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds()))) w.maxAge = min(w.writer.config.DefaultTTL, max(0, int(time.Until(*expires).Seconds())))
} }
} }
// SetLastModified sets the Last-Modified header from request // SetLastModified sets the Last-Modified header from request
func (w *Writer) SetLastModified() { func (w *writer) SetLastModified() {
if !w.config.LastModifiedEnabled { if !w.writer.config.LastModifiedEnabled {
return return
} }
@@ -89,14 +105,14 @@ func (w *Writer) SetLastModified() {
} }
// SetVary sets the Vary header // SetVary sets the Vary header
func (w *Writer) SetVary() { func (w *writer) SetVary() {
if len(w.varyValue) > 0 { if len(w.writer.varyValue) > 0 {
w.result.Set(httpheaders.Vary, w.varyValue) w.result.Set(httpheaders.Vary, w.writer.varyValue)
} }
} }
// Passthrough copies specified headers from the original response headers to the response headers. // Passthrough copies specified headers from the original response headers to the response headers.
func (w *Writer) Passthrough(only []string) { func (w *writer) Passthrough(only []string) {
for _, key := range only { for _, key := range only {
values := w.originalResponseHeaders.Values(key) values := w.originalResponseHeaders.Values(key)
@@ -108,7 +124,7 @@ func (w *Writer) Passthrough(only []string) {
// CopyFrom copies specified headers from the headers object. Please note that // CopyFrom copies specified headers from the headers object. Please note that
// all the past operations may overwrite those values. // all the past operations may overwrite those values.
func (w *Writer) CopyFrom(headers http.Header, only []string) { func (w *writer) CopyFrom(headers http.Header, only []string) {
for _, key := range only { for _, key := range only {
values := headers.Values(key) values := headers.Values(key)
@@ -119,7 +135,7 @@ func (w *Writer) CopyFrom(headers http.Header, only []string) {
} }
// SetContentLength sets the Content-Length header // SetContentLength sets the Content-Length header
func (w *Writer) SetContentLength(contentLength int) { func (w *writer) SetContentLength(contentLength int) {
if contentLength < 0 { if contentLength < 0 {
return return
} }
@@ -128,25 +144,25 @@ func (w *Writer) SetContentLength(contentLength int) {
} }
// SetContentType sets the Content-Type header // SetContentType sets the Content-Type header
func (w *Writer) SetContentType(mime string) { func (w *writer) SetContentType(mime string) {
w.result.Set(httpheaders.ContentType, mime) w.result.Set(httpheaders.ContentType, mime)
} }
// writeCanonical sets the Link header with the canonical URL. // writeCanonical sets the Link header with the canonical URL.
// It is mandatory for any response if enabled in the configuration. // It is mandatory for any response if enabled in the configuration.
func (b *Writer) SetCanonical() { func (w *writer) SetCanonical() {
if !b.config.SetCanonicalHeader { if !w.writer.config.SetCanonicalHeader {
return return
} }
if strings.HasPrefix(b.url, "https://") || strings.HasPrefix(b.url, "http://") { if strings.HasPrefix(w.url, "https://") || strings.HasPrefix(w.url, "http://") {
value := fmt.Sprintf(`<%s>; rel="canonical"`, b.url) value := fmt.Sprintf(`<%s>; rel="canonical"`, w.url)
b.result.Set(httpheaders.Link, value) w.result.Set(httpheaders.Link, value)
} }
} }
// setCacheControl sets the Cache-Control header with the specified value. // setCacheControl sets the Cache-Control header with the specified value.
func (w *Writer) setCacheControl(value int) bool { func (w *writer) setCacheControl(value int) bool {
if value <= 0 { if value <= 0 {
return false return false
} }
@@ -156,14 +172,14 @@ func (w *Writer) setCacheControl(value int) bool {
} }
// setCacheControlNoCache sets the Cache-Control header to no-cache (default). // setCacheControlNoCache sets the Cache-Control header to no-cache (default).
func (w *Writer) setCacheControlNoCache() { func (w *writer) setCacheControlNoCache() {
w.result.Set(httpheaders.CacheControl, "no-cache") w.result.Set(httpheaders.CacheControl, "no-cache")
} }
// setCacheControlPassthrough sets the Cache-Control header from the request // setCacheControlPassthrough sets the Cache-Control header from the request
// if passthrough is enabled in the configuration. // if passthrough is enabled in the configuration.
func (w *Writer) setCacheControlPassthrough() bool { func (w *writer) setCacheControlPassthrough() bool {
if !w.config.CacheControlPassthrough || w.maxAge > 0 { if !w.writer.config.CacheControlPassthrough || w.maxAge > 0 {
return false return false
} }
@@ -183,18 +199,18 @@ func (w *Writer) setCacheControlPassthrough() bool {
} }
// setCSP sets the Content-Security-Policy header to prevent script execution. // setCSP sets the Content-Security-Policy header to prevent script execution.
func (w *Writer) setCSP() { func (w *writer) setCSP() {
w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'") w.result.Set(httpheaders.ContentSecurityPolicy, "script-src 'none'")
} }
// Write writes the headers to the response writer. It does not overwrite // Write writes the headers to the response writer. It does not overwrite
// target headers, which were set outside the header writer. // target headers, which were set outside the header writer.
func (w *Writer) Write(rw http.ResponseWriter) { func (w *writer) Write(rw http.ResponseWriter) {
// Then, let's try to set Cache-Control using priority order // Then, let's try to set Cache-Control using priority order
switch { switch {
case w.setCacheControl(w.maxAge): // First, try set explicit case w.setCacheControl(w.maxAge): // First, try set explicit
case w.setCacheControlPassthrough(): // Try to pick up from request headers case w.setCacheControlPassthrough(): // Try to pick up from request headers
case w.setCacheControl(w.config.DefaultTTL): // Fallback to default value case w.setCacheControl(w.writer.config.DefaultTTL): // Fallback to default value
default: default:
w.setCacheControlNoCache() // By default we use no-cache w.setCacheControlNoCache() // By default we use no-cache
} }

View File

@@ -23,7 +23,7 @@ type writerTestCase struct {
req http.Header req http.Header
res http.Header res http.Header
config Config config Config
fn func(*Writer) fn func(*writer)
} }
func (s *HeaderWriterSuite) TestHeaderCases() { func (s *HeaderWriterSuite) TestHeaderCases() {
@@ -105,7 +105,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
SetCanonicalHeader: true, SetCanonicalHeader: true,
DefaultTTL: 3600, DefaultTTL: 3600,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetCanonical() w.SetCanonical()
}, },
}, },
@@ -134,7 +134,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
SetCanonicalHeader: false, SetCanonicalHeader: false,
DefaultTTL: 3600, DefaultTTL: 3600,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetCanonical() w.SetCanonical()
}, },
}, },
@@ -152,7 +152,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
LastModifiedEnabled: true, LastModifiedEnabled: true,
DefaultTTL: 3600, DefaultTTL: 3600,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetLastModified() w.SetLastModified()
}, },
}, },
@@ -167,7 +167,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
DefaultTTL: 3600, DefaultTTL: 3600,
FallbackImageTTL: 1, FallbackImageTTL: 1,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetIsFallbackImage() w.SetIsFallbackImage()
}, },
}, },
@@ -181,7 +181,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{ config: Config{
DefaultTTL: math.MaxInt32, DefaultTTL: math.MaxInt32,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetExpires(&expires) w.SetExpires(&expires)
}, },
}, },
@@ -196,7 +196,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
DefaultTTL: math.MaxInt32, DefaultTTL: math.MaxInt32,
FallbackImageTTL: 600, FallbackImageTTL: 600,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetIsFallbackImage() w.SetIsFallbackImage()
w.SetExpires(&shortExpires) w.SetExpires(&shortExpires)
}, },
@@ -213,7 +213,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
EnableClientHints: true, EnableClientHints: true,
SetVaryAccept: true, SetVaryAccept: true,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetVary() w.SetVary()
}, },
}, },
@@ -228,7 +228,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
}, },
config: Config{}, config: Config{},
fn: func(w *Writer) { fn: func(w *writer) {
w.Passthrough([]string{"X-Test"}) w.Passthrough([]string{"X-Test"})
}, },
}, },
@@ -241,7 +241,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
}, },
config: Config{}, config: Config{},
fn: func(w *Writer) { fn: func(w *writer) {
h := http.Header{} h := http.Header{}
h.Set("X-From", "baz") h.Set("X-From", "baz")
w.CopyFrom(h, []string{"X-From"}) w.CopyFrom(h, []string{"X-From"})
@@ -256,7 +256,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
}, },
config: Config{}, config: Config{},
fn: func(w *Writer) { fn: func(w *writer) {
w.SetContentLength(123) w.SetContentLength(123)
}, },
}, },
@@ -269,7 +269,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"}, httpheaders.ContentSecurityPolicy: []string{"script-src 'none'"},
}, },
config: Config{}, config: Config{},
fn: func(w *Writer) { fn: func(w *writer) {
w.SetContentType("image/png") w.SetContentType("image/png")
}, },
}, },
@@ -283,7 +283,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{ config: Config{
DefaultTTL: 3600, DefaultTTL: 3600,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetExpires(nil) w.SetExpires(nil)
}, },
}, },
@@ -298,7 +298,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{ config: Config{
SetVaryAccept: true, SetVaryAccept: true,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetVary() w.SetVary()
}, },
}, },
@@ -313,7 +313,7 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
config: Config{ config: Config{
EnableClientHints: true, EnableClientHints: true,
}, },
fn: func(w *Writer) { fn: func(w *writer) {
w.SetVary() w.SetVary()
}, },
}, },
@@ -321,7 +321,10 @@ func (s *HeaderWriterSuite) TestHeaderCases() {
for _, tc := range tt { for _, tc := range tt {
s.Run(tc.name, func() { s.Run(tc.name, func() {
writer := New(&tc.config, tc.req, tc.url) factory, err := New(&tc.config)
s.Require().NoError(err)
writer := factory.NewRequest(tc.req, tc.url)
if tc.fn != nil { if tc.fn != nil {
tc.fn(writer) tc.fn(writer)

View File

@@ -40,7 +40,12 @@ func initDownloading() error {
return err return err
} }
Fetcher, err = imagefetcher.NewFetcher(ts, imagefetcher.NewConfigFromEnv()) c, err := imagefetcher.NewDefaultConfig().LoadFromEnv()
if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithPrefix("configuration error"))
}
Fetcher, err = imagefetcher.NewFetcher(ts, c)
if err != nil { if err != nil {
return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create image fetcher")) return ierrors.Wrap(err, 0, ierrors.WithPrefix("can't create image fetcher"))
} }

View File

@@ -8,9 +8,20 @@ type Config struct {
MaxRedirects int MaxRedirects int
} }
// NewConfigFromEnv creates a new Config instance from environment variables or defaults. // NewDefaultConfig returns a new Config instance with default values.
func NewConfigFromEnv() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
MaxRedirects: config.MaxRedirects, MaxRedirects: 10,
} }
} }
// LoadFromEnv loads config variables from env
func (c *Config) LoadFromEnv() (*Config, error) {
c.MaxRedirects = config.MaxRedirects
return c, nil
}
// Validate checks config for errors
func (c *Config) Validate() error {
return nil
}

View File

@@ -26,6 +26,10 @@ type Fetcher struct {
// NewFetcher creates a new ImageFetcher with the provided transport // NewFetcher creates a new ImageFetcher with the provided transport
func NewFetcher(transport *transport.Transport, config *Config) (*Fetcher, error) { func NewFetcher(transport *transport.Transport, config *Config) (*Fetcher, error) {
if err := config.Validate(); err != nil {
return nil, err
}
return &Fetcher{transport, config}, nil return &Fetcher{transport, config}, nil
} }

12
main.go
View File

@@ -137,8 +137,16 @@ func run(ctx context.Context) error {
return err return err
} }
cfg := server.NewConfigFromEnv() cfg, err := server.NewDefaultConfig().LoadFromEnv()
r := server.NewRouter(cfg) if err != nil {
return err
}
r, err := server.NewRouter(cfg)
if err != nil {
return err
}
s, err := server.Start(cancel, buildRouter(r)) s, err := server.Start(cancel, buildRouter(r))
if err != nil { if err != nil {
return err return err

View File

@@ -277,10 +277,29 @@ func handleProcessing(reqID string, rw http.ResponseWriter, r *http.Request) err
} }
if po.Raw { if po.Raw {
// NOTE: This is temporary, there would be no categoryConfig once we
// finish with refactoring.
// TODO: Move this up // TODO: Move this up
cfg := stream.NewConfigFromEnv() cfg, cerr := stream.NewDefaultConfig().LoadFromEnv()
hwCfg := headerwriter.NewConfigFromEnv() if cerr != nil {
handler := stream.New(cfg, hwCfg, imagedata.Fetcher) return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig))
}
hwc, cerr := headerwriter.NewDefaultConfig().LoadFromEnv()
if cerr != nil {
return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig))
}
hw, cerr := headerwriter.New(hwc)
if cerr != nil {
return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig))
}
handler, cerr := stream.New(cfg, hw, imagedata.Fetcher)
if cerr != nil {
return ierrors.Wrap(cerr, 0, ierrors.WithCategory(categoryConfig))
}
return handler.Execute(ctx, r, imageURL, reqID, po, rw) return handler.Execute(ctx, r, imageURL, reqID, po, rw)
} }

View File

@@ -48,7 +48,11 @@ func (s *ProcessingHandlerTestSuite) SetupSuite() {
logrus.SetOutput(io.Discard) logrus.SetOutput(io.Discard)
s.router = buildRouter(server.NewRouter(server.NewConfigFromEnv())) cfg := server.NewDefaultConfig()
r, err := server.NewRouter(cfg)
s.Require().NoError(err)
s.router = buildRouter(r)
} }
func (s *ProcessingHandlerTestSuite) TeardownSuite() { func (s *ProcessingHandlerTestSuite) TeardownSuite() {

View File

@@ -1,6 +1,8 @@
package server package server
import ( import (
"errors"
"fmt"
"time" "time"
"github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/config"
@@ -19,6 +21,7 @@ type Config struct {
PathPrefix string // Path prefix for the server PathPrefix string // Path prefix for the server
MaxClients int // Maximum number of concurrent clients MaxClients int // Maximum number of concurrent clients
ReadRequestTimeout time.Duration // Timeout for reading requests ReadRequestTimeout time.Duration // Timeout for reading requests
WriteResponseTimeout time.Duration // Timeout for writing responses
KeepAliveTimeout time.Duration // Timeout for keep-alive connections KeepAliveTimeout time.Duration // Timeout for keep-alive connections
GracefulTimeout time.Duration // Timeout for graceful shutdown GracefulTimeout time.Duration // Timeout for graceful shutdown
CORSAllowOrigin string // CORS allowed origin CORSAllowOrigin string // CORS allowed origin
@@ -26,24 +29,70 @@ type Config struct {
DevelopmentErrorsMode bool // Enable development mode for detailed error messages DevelopmentErrorsMode bool // Enable development mode for detailed error messages
SocketReusePort bool // Enable SO_REUSEPORT socket option SocketReusePort bool // Enable SO_REUSEPORT socket option
HealthCheckPath string // Health check path from config HealthCheckPath string // Health check path from config
WriteResponseTimeout time.Duration
} }
// NewConfigFromEnv creates a new Config instance from environment variables // NewDefaultConfig returns default config values
func NewConfigFromEnv() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
Network: config.Network, Network: "tcp",
Bind: config.Bind, Bind: ":8080",
PathPrefix: config.PathPrefix, PathPrefix: "",
MaxClients: config.MaxClients, MaxClients: 2048,
ReadRequestTimeout: time.Duration(config.ReadRequestTimeout) * time.Second, ReadRequestTimeout: 10 * time.Second,
KeepAliveTimeout: time.Duration(config.KeepAliveTimeout) * time.Second, KeepAliveTimeout: 10 * time.Second,
WriteResponseTimeout: 10 * time.Second,
GracefulTimeout: gracefulTimeout, GracefulTimeout: gracefulTimeout,
CORSAllowOrigin: config.AllowOrigin, CORSAllowOrigin: "",
Secret: config.Secret, Secret: "",
DevelopmentErrorsMode: config.DevelopmentErrorsMode, DevelopmentErrorsMode: false,
SocketReusePort: config.SoReuseport, SocketReusePort: false,
HealthCheckPath: config.HealthCheckPath, HealthCheckPath: "",
WriteResponseTimeout: time.Duration(config.WriteResponseTimeout) * time.Second,
} }
} }
// LoadFromEnv overrides current values with environment variables
func (c *Config) LoadFromEnv() (*Config, error) {
c.Network = config.Network
c.Bind = config.Bind
c.PathPrefix = config.PathPrefix
c.MaxClients = config.MaxClients
c.ReadRequestTimeout = time.Duration(config.ReadRequestTimeout) * time.Second
c.KeepAliveTimeout = time.Duration(config.KeepAliveTimeout) * time.Second
c.GracefulTimeout = gracefulTimeout
c.CORSAllowOrigin = config.AllowOrigin
c.Secret = config.Secret
c.DevelopmentErrorsMode = config.DevelopmentErrorsMode
c.SocketReusePort = config.SoReuseport
c.HealthCheckPath = config.HealthCheckPath
return c, nil
}
// Validate checks that the config values are valid
func (c *Config) Validate() error {
if len(c.Bind) == 0 {
return errors.New("bind address is not defined")
}
if c.MaxClients < 0 {
return fmt.Errorf("max clients number should be greater than or equal 0, now - %d", c.MaxClients)
}
if c.ReadRequestTimeout <= 0 {
return fmt.Errorf("read request timeout should be greater than 0, now - %d", c.ReadRequestTimeout)
}
if c.WriteResponseTimeout <= 0 {
return fmt.Errorf("write response timeout should be greater than 0, now - %d", c.WriteResponseTimeout)
}
if c.KeepAliveTimeout < 0 {
return fmt.Errorf("keep alive timeout should be greater than or equal to 0, now - %d", c.KeepAliveTimeout)
}
if c.GracefulTimeout < 0 {
return fmt.Errorf("graceful timeout should be greater than or equal to 0, now - %d", c.GracefulTimeout)
}
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"net" "net"
"net/http" "net/http"
"regexp" "regexp"
"slices"
"strings" "strings"
nanoid "github.com/matoous/go-nanoid/v2" nanoid "github.com/matoous/go-nanoid/v2"
@@ -47,8 +48,12 @@ type Router struct {
} }
// NewRouter creates a new Router instance // NewRouter creates a new Router instance
func NewRouter(config *Config) *Router { func NewRouter(config *Config) (*Router, error) {
return &Router{config: config} if err := config.Validate(); err != nil {
return nil, err
}
return &Router{config: config}, nil
} }
// add adds an abitary route to the router // add adds an abitary route to the router
@@ -57,14 +62,29 @@ func (r *Router) add(method, prefix string, exact bool, handler RouteHandler, mi
handler = m(handler) handler = m(handler)
} }
route := &route{method: method, path: r.config.PathPrefix + prefix, handler: handler, exact: exact} newRoute := &route{
method: method,
path: r.config.PathPrefix + prefix,
handler: handler,
exact: exact,
}
r.routes = append( r.routes = append(r.routes, newRoute)
r.routes,
route,
)
return route // Sort routes by exact flag, exact routes go first in the
// same order they were added
slices.SortStableFunc(r.routes, func(a, b *route) int {
switch {
case a.exact == b.exact:
return 0
case a.exact:
return -1
default:
return 1
}
})
return newRoute
} }
// GET adds GET route // GET adds GET route

View File

@@ -16,13 +16,13 @@ type RouterTestSuite struct {
} }
func (s *RouterTestSuite) SetupTest() { func (s *RouterTestSuite) SetupTest() {
c := NewConfigFromEnv() c := NewDefaultConfig()
c.PathPrefix = "/api"
s.router = NewRouter(c)
}
func TestRouterSuite(t *testing.T) { c.PathPrefix = "/api"
suite.Run(t, new(RouterTestSuite)) r, err := NewRouter(c)
s.Require().NoError(err)
s.router = r
} }
// TestHTTPMethods tests route methods registration and HTTP requests // TestHTTPMethods tests route methods registration and HTTP requests
@@ -294,3 +294,23 @@ func (s *RouterTestSuite) TestReplaceIP() {
}) })
} }
} }
// TestRouteOrder checks exact/non-exact insertion order
func (s *RouterTestSuite) TestRouteOrder() {
h := func(reqID string, rw http.ResponseWriter, req *http.Request) error {
return nil
}
s.router.GET("/test", false, h)
s.router.GET("/test/path", true, h)
s.router.GET("/test/path/nested", true, h)
s.Require().Equal("/api/test/path", s.router.routes[0].path)
s.Require().Equal("/api/test/path/nested", s.router.routes[1].path)
s.Require().Equal("/api/test", s.router.routes[2].path)
}
func TestRouterSuite(t *testing.T) {
suite.Run(t, new(RouterTestSuite))
}

View File

@@ -9,7 +9,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/imgproxy/imgproxy/v3/config"
"github.com/imgproxy/imgproxy/v3/httpheaders" "github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@@ -21,10 +20,13 @@ type ServerTestSuite struct {
} }
func (s *ServerTestSuite) SetupTest() { func (s *ServerTestSuite) SetupTest() {
config.Reset() c := NewDefaultConfig()
s.config = NewConfigFromEnv()
s.config = c
s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment s.config.Bind = "127.0.0.1:0" // Use port 0 for auto-assignment
s.blankRouter = NewRouter(s.config) r, err := NewRouter(s.config)
s.Require().NoError(err)
s.blankRouter = r
} }
func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error { func (s *ServerTestSuite) mockHandler(reqID string, rw http.ResponseWriter, r *http.Request) error {
@@ -41,12 +43,11 @@ func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
cancelCalled.Store(true) cancelCalled.Store(true)
} }
invalidConfig := &Config{ invalidConfig := NewDefaultConfig()
Network: "tcp", invalidConfig.Bind = "-1.-1.-1.-1" // Invalid address
Bind: "invalid-address", // Invalid address
}
r := NewRouter(invalidConfig) r, err := NewRouter(invalidConfig)
s.Require().NoError(err)
server, err := Start(cancelWrapper, r) server, err := Start(cancelWrapper, r)
@@ -109,10 +110,11 @@ func (s *ServerTestSuite) TestWithCORS() {
for _, tt := range tests { for _, tt := range tests {
s.Run(tt.name, func() { s.Run(tt.name, func() {
config := &Config{ config := NewDefaultConfig()
CORSAllowOrigin: tt.corsAllowOrigin, config.CORSAllowOrigin = tt.corsAllowOrigin
}
router := NewRouter(config) router, err := NewRouter(config)
s.Require().NoError(err)
wrappedHandler := router.WithCORS(s.mockHandler) wrappedHandler := router.WithCORS(s.mockHandler)
@@ -154,10 +156,11 @@ func (s *ServerTestSuite) TestWithSecret() {
for _, tt := range tests { for _, tt := range tests {
s.Run(tt.name, func() { s.Run(tt.name, func() {
config := &Config{ config := NewDefaultConfig()
Secret: tt.secret, config.Secret = tt.secret
}
router := NewRouter(config) router, err := NewRouter(config)
s.Require().NoError(err)
wrappedHandler := router.WithSecret(s.mockHandler) wrappedHandler := router.WithSecret(s.mockHandler)
@@ -167,7 +170,7 @@ func (s *ServerTestSuite) TestWithSecret() {
} }
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
err := wrappedHandler("test-req-id", rw, req) err = wrappedHandler("test-req-id", rw, req)
if tt.expectError { if tt.expectError {
s.Require().Error(err) s.Require().Error(err)