diff --git a/server/internal/integrations/lark/ws_connector.go b/server/internal/integrations/lark/ws_connector.go index 9e4a56f95..228723fff 100644 --- a/server/internal/integrations/lark/ws_connector.go +++ b/server/internal/integrations/lark/ws_connector.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/http" + "net/url" "sync" "time" @@ -533,6 +534,14 @@ func (f FrameDecoderFunc) Decode(payload []byte, inst db.LarkInstallation) (Inbo // GorillaDialer is the production WSDialer. type GorillaDialer struct { Dialer *websocket.Dialer + + // Proxy is the proxy function for WebSocket connections. When nil + // (the zero value), the dialer defaults to http.ProxyFromEnvironment + // so standard HTTPS_PROXY / HTTP_PROXY / NO_PROXY environment + // variables are respected. Set Proxy to a non-nil func to override + // (e.g. for custom proxy auth or a fixed proxy URL). To disable proxy + // entirely, pass a func that returns (nil, nil). + Proxy func(*http.Request) (*url.URL, error) } func NewGorillaDialer() *GorillaDialer { @@ -548,7 +557,15 @@ func (g *GorillaDialer) DialContext(ctx context.Context, urlStr string, requestH if d == nil { d = websocket.DefaultDialer } - c, resp, err := d.DialContext(ctx, urlStr, requestHeader) + // Shallow copy so we don't mutate the shared dialer's Proxy field. + dd := *d + if g.Proxy != nil { + dd.Proxy = g.Proxy + } + if dd.Proxy == nil { + dd.Proxy = http.ProxyFromEnvironment + } + c, resp, err := dd.DialContext(ctx, urlStr, requestHeader) if err != nil { return nil, resp, err } diff --git a/server/internal/integrations/lark/ws_connector_test.go b/server/internal/integrations/lark/ws_connector_test.go index c1e0b65d3..3b875f951 100644 --- a/server/internal/integrations/lark/ws_connector_test.go +++ b/server/internal/integrations/lark/ws_connector_test.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/http" + "net/url" "strconv" "sync" "sync/atomic" @@ -651,6 +652,70 @@ func TestWSConnectorReassemblesChunkedDataFrame(t *testing.T) { } } +func TestGorillaDialerPreservesConfiguredDialerProxy(t *testing.T) { + t.Parallel() + + proxyErr := errors.New("configured proxy refused") + d := &GorillaDialer{ + Dialer: &websocket.Dialer{ + Proxy: func(*http.Request) (*url.URL, error) { + return nil, proxyErr + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, _, err := d.DialContext(ctx, "ws://127.0.0.1:1", nil) + if !errors.Is(err, proxyErr) { + t.Fatalf("DialContext error = %v, want %v", err, proxyErr) + } +} + +func TestGorillaDialerProxyOverridesConfiguredDialerProxy(t *testing.T) { + t.Parallel() + + configuredProxyErr := errors.New("configured proxy refused") + overrideProxyErr := errors.New("override proxy refused") + d := &GorillaDialer{ + Dialer: &websocket.Dialer{ + Proxy: func(*http.Request) (*url.URL, error) { + return nil, configuredProxyErr + }, + }, + Proxy: func(*http.Request) (*url.URL, error) { + return nil, overrideProxyErr + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, _, err := d.DialContext(ctx, "ws://127.0.0.1:1", nil) + if !errors.Is(err, overrideProxyErr) { + t.Fatalf("DialContext error = %v, want %v", err, overrideProxyErr) + } + if errors.Is(err, configuredProxyErr) { + t.Fatalf("DialContext used configured proxy error %v instead of override", configuredProxyErr) + } +} + +func TestGorillaDialerProxyForwardsError(t *testing.T) { + t.Parallel() + + d := NewGorillaDialer() + proxyErr := errors.New("proxy refused") + d.Proxy = func(r *http.Request) (*url.URL, error) { + return nil, proxyErr + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, _, err := d.DialContext(ctx, "ws://127.0.0.1:1", nil) + if !errors.Is(err, proxyErr) { + t.Fatalf("DialContext error = %v, want %v", err, proxyErr) + } +} + func TestWSConnectorCredentialsErrorIsReturned(t *testing.T) { t.Parallel() credsErr := errors.New("decrypt failed")