diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 1349e6534..c0614f50c 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -8,6 +8,7 @@ import ( "net" "sync" "testing" + "testing/iotest" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" @@ -536,3 +537,27 @@ func TestBolt0008TestVectors(t *testing.T) { buf.Reset() } } + +// timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after +// writing n bytes. +type timeoutWriter struct { + w io.Writer + n int64 +} + +func NewTimeoutWriter(w io.Writer, n int64) io.Writer { + return &timeoutWriter{w, n} +} + +func (t *timeoutWriter) Write(p []byte) (int, error) { + n := len(p) + if int64(n) > t.n { + n = int(t.n) + } + n, err := t.w.Write(p[:n]) + t.n -= int64(n) + if err == nil && t.n == 0 { + return n, iotest.ErrTimeout + } + return n, err +}