From a68d8d4823b761909a09ebf84b1e663b8104155e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Fri, 21 Mar 2025 19:14:33 +0100
Subject: [PATCH] rpcwallet: allow remote signer to reconnect

Allow the remote signer to reconnect to the wallet after disconnecting,
as long as the remote signer reconnects within the timeout limit.

This is not a complete solution to the problem of keeping the watch-only
node online while the remote signer is disconnected, but it is more
fault-tolerant than the current implementation, as it allows the remote
signer to be temporarily disconnected.
---
 lnwallet/rpcwallet/sign_coordinator.go      | 121 ++++++---
 lnwallet/rpcwallet/sign_coordinator_test.go | 281 ++++++++++++++++++--
 2 files changed, 351 insertions(+), 51 deletions(-)

diff --git a/lnwallet/rpcwallet/sign_coordinator.go b/lnwallet/rpcwallet/sign_coordinator.go
index ff33e4b85..8fc8e60b9 100644
--- a/lnwallet/rpcwallet/sign_coordinator.go
+++ b/lnwallet/rpcwallet/sign_coordinator.go
@@ -8,6 +8,9 @@ import (
     "sync/atomic"
     "time"
 
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"
+
     "github.com/lightningnetwork/lnd/lnrpc/signrpc"
     "github.com/lightningnetwork/lnd/lnrpc/walletrpc"
     "github.com/lightningnetwork/lnd/lnutils"
@@ -77,16 +80,20 @@ type SignCoordinator struct {
     // signer has errored, and we can no longer process responses.
     receiveErrChan chan error
 
-    // doneReceiving is closed when either party terminates and signals to
+    // disconnected is closed when either party terminates and signals to
     // any pending requests that we'll no longer process the response for
     // that request.
-    doneReceiving chan struct{}
+    disconnected chan struct{}
 
     // quit is closed when lnd is shutting down.
     quit chan struct{}
 
-    // clientConnected is sent over when the remote signer connects.
-    clientConnected chan struct{}
+    // clientReady is closed when the remote signer is connected and ready
+    // to accept requests, i.e. after the initial handshake has completed.
+    clientReady chan struct{}
+
+    // clientConnected is true if a remote signer is currently connected.
+    clientConnected bool
 
     // requestTimeout is the maximum time we will wait for a response from
     // the remote signer.
@@ -114,11 +121,14 @@ func NewSignCoordinator(requestTimeout time.Duration,
     s := &SignCoordinator{
         responses:         respsMap,
         receiveErrChan:    make(chan error, 1),
-        doneReceiving:     make(chan struct{}),
-        clientConnected:   make(chan struct{}),
+        clientReady:       make(chan struct{}),
+        clientConnected:   false,
         quit:              make(chan struct{}),
         requestTimeout:    requestTimeout,
         connectionTimeout: connectionTimeout,
+        // Note that the disconnected channel is not initialized here,
+        // as no code listens to it until the Run method has been
+        // called and has set the field.
     }
 
     // We initialize the atomic nextRequestID to the handshakeRequestID, as
@@ -139,25 +149,33 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
         s.mu.Unlock()
         return ErrShuttingDown
 
-    case <-s.doneReceiving:
-        s.mu.Unlock()
-        return ErrNotConnected
-
     default:
     }
 
+    if s.clientConnected {
+        // We only support one remote signer connection at a time.
+        s.mu.Unlock()
+        return ErrMultipleConnections
+    }
+
     s.wg.Add(1)
     defer s.wg.Done()
 
-    // If we already have a stream, we error out as we can only have one
-    // connection throughout the lifetime of the SignCoordinator.
-    if s.stream != nil {
-        s.mu.Unlock()
-        return ErrMultipleConnections
-    }
+    s.clientConnected = true
+    defer func() {
+        s.mu.Lock()
+        defer s.mu.Unlock()
+
+        // When `Run` returns, we set the clientConnected field to false
+        // to allow a new remote signer connection to be set up.
+        s.clientConnected = false
+    }()
 
     s.stream = stream
+    s.disconnected = make(chan struct{})
+    defer close(s.disconnected)
+
     s.mu.Unlock()
 
     // The handshake must be completed before we can start sending requests
@@ -167,8 +185,18 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
         return err
     }
 
-    log.Infof("Remote signer connected")
-    close(s.clientConnected)
+    log.Infof("Remote signer connected and ready")
+
+    close(s.clientReady)
+    defer func() {
+        s.mu.Lock()
+        defer s.mu.Unlock()
+
+        // Once this function has exited, we create a new clientReady
+        // channel to ensure that a new remote signer connection can be
+        // set up.
+        s.clientReady = make(chan struct{})
+    }()
 
     // Now let's start the main receiving loop, which will receive all
     // responses to our requests from the remote signer!
@@ -186,9 +214,6 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
 
     case <-s.quit:
         return ErrShuttingDown
-
-    case <-s.doneReceiving:
-        return ErrNotConnected
     }
 }
 
@@ -371,10 +396,6 @@ func (s *SignCoordinator) handshake(stream StreamServer) error {
 func (s *SignCoordinator) StartReceiving() {
     defer s.wg.Done()
 
-    // Signals to any ongoing requests that the remote signer is no longer
-    // connected.
-    defer close(s.doneReceiving)
-
     for {
         resp, err := s.stream.Recv()
         if err != nil {
@@ -442,8 +463,16 @@ func (s *SignCoordinator) WaitUntilConnected() error {
 func (s *SignCoordinator) waitUntilConnectedWithTimeout(
     timeout time.Duration) error {
 
+    // As the Run method will redefine the clientReady channel once it
+    // returns, we need to copy the current clientReady channel reference
+    // to ensure that we're waiting on the correct channel, and to avoid
+    // a data race.
+    s.mu.Lock()
+    currentClientReady := s.clientReady
+    s.mu.Unlock()
+
     select {
-    case <-s.clientConnected:
+    case <-currentClientReady:
         return nil
 
     case <-s.quit:
@@ -451,9 +480,6 @@ func (s *SignCoordinator) waitUntilConnectedWithTimeout(
 
     case <-time.After(timeout):
         return ErrConnectTimeout
-
-    case <-s.doneReceiving:
-        return ErrNotConnected
     }
 }
 
@@ -537,7 +563,7 @@ func (s *SignCoordinator) getResponse(requestID uint64,
 
         return resp, nil
 
-    case <-s.doneReceiving:
+    case <-s.disconnected:
         log.Debugf("Stopped waiting for remote signer response for "+
             "request ID %d as the stream has been closed",
             requestID)
@@ -854,8 +880,36 @@ func processRequest[R comparable](s *SignCoordinator, timeout time.Duration,
 
     log.Tracef("Request content: %v", formatSignCoordinatorMsg(&req))
 
+    // reprocessOnDisconnect is a helper function that is used when the
+    // remote signer disconnects while we're processing a request: it
+    // waits for the signer to reconnect within the configured timeout,
+    // and then resends the request.
+    reprocessOnDisconnect := func() (R, error) {
+        var newTimeout time.Duration = noTimeout
+
+        if timeout != 0 {
+            newTimeout = timeout - time.Since(startTime)
+
+            if time.Since(startTime) > timeout {
+                return zero, ErrRequestTimeout
+            }
+        }
+
+        return processRequest[R](
+            s, newTimeout, generateRequest, extractResponse,
+        )
+    }
+
     err = s.stream.Send(&req)
     if err != nil {
+        st, isStatusError := status.FromError(err)
+        if isStatusError && st.Code() == codes.Unavailable {
+            // If the stream was closed due to the remote signer
+            // disconnecting, we will retry processing the request
+            // once the remote signer reconnects.
+            return reprocessOnDisconnect()
+        }
+
         return zero, err
     }
 
@@ -878,7 +932,12 @@ func processRequest[R comparable](s *SignCoordinator, timeout time.Duration,
         resp, err = s.getResponse(reqID, s.requestTimeout)
     }
 
-    if err != nil {
+    if errors.Is(err, ErrNotConnected) {
+        // If the remote signer disconnected while we were waiting for
+        // the response, we will retry processing the request once the
+        // remote signer reconnects.
+        return reprocessOnDisconnect()
+    } else if err != nil {
         return zero, err
     }
 
diff --git a/lnwallet/rpcwallet/sign_coordinator_test.go b/lnwallet/rpcwallet/sign_coordinator_test.go
index e394e08c7..73d71b679 100644
--- a/lnwallet/rpcwallet/sign_coordinator_test.go
+++ b/lnwallet/rpcwallet/sign_coordinator_test.go
@@ -6,6 +6,9 @@ import (
     "testing"
     "time"
 
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"
+
     "github.com/lightningnetwork/lnd/lnrpc/walletrpc"
     "github.com/lightningnetwork/lnd/lntest/wait"
     "github.com/stretchr/testify/require"
@@ -19,6 +22,10 @@ type mockSCStream struct {
     // sign coordinator to the remote signer.
     sendChan chan *walletrpc.SignCoordinatorRequest
 
+    // sendErrorChan is used to simulate errors returned by the stream when
+    // the sign coordinator sends a request to the remote signer.
+    sendErrorChan chan error
+
     // recvChan is used to simulate responses sent over the stream from the
     // remote signer to the sign coordinator.
     recvChan chan *walletrpc.SignCoordinatorResponse
@@ -32,10 +39,11 @@ type mockSCStream struct {
 // newMockSCStream creates a new mock stream.
 func newMockSCStream() *mockSCStream {
     return &mockSCStream{
-        sendChan:   make(chan *walletrpc.SignCoordinatorRequest, 1),
-        recvChan:   make(chan *walletrpc.SignCoordinatorResponse, 1),
-        cancelChan: make(chan struct{}),
-        ctx:        context.Background(),
+        sendChan:      make(chan *walletrpc.SignCoordinatorRequest),
+        sendErrorChan: make(chan error),
+        recvChan:      make(chan *walletrpc.SignCoordinatorResponse, 1),
+        cancelChan:    make(chan struct{}),
+        ctx:           context.Background(),
     }
 }
 
@@ -46,8 +54,8 @@ func (ms *mockSCStream) Send(req *walletrpc.SignCoordinatorRequest) error {
     case ms.sendChan <- req:
         return nil
 
-    case <-ms.cancelChan:
-        return ErrStreamCanceled
+    case err := <-ms.sendErrorChan:
+        return err
     }
 }
 
@@ -94,8 +102,19 @@ func (ms *mockSCStream) sendResponse(resp *walletrpc.SignCoordinatorResponse) {
 func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
     chan error) {
 
-    stream := newMockSCStream()
     coordinator := NewSignCoordinator(2*time.Second, 3*time.Second)
+    stream, errChan := setupNewStream(t, coordinator)
+
+    return coordinator, stream, errChan
+}
+
+// setupNewStream sets up a new mock stream to simulate communication with a
+// remote signer. It also simulates the handshake between the passed sign
+// coordinator and the remote signer.
+func setupNewStream(t *testing.T,
+    coordinator *SignCoordinator) (*mockSCStream, chan error) {
+
+    stream := newMockSCStream()
     errChan := make(chan error)
 
     go func() {
@@ -136,7 +155,7 @@ func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
         )
     }
 
-    return coordinator, stream, errChan
+    return stream, errChan
 }
 
 // getRequest is a helper function to get a request that has been sent from
@@ -369,21 +388,43 @@ func TestPingTimeout(t *testing.T) {
 
     coordinator, stream, _ := setupSignCoordinator(t)
 
-    // Simulate a Ping request that times out.
-    _, err := coordinator.Ping(1 * time.Second)
-    require.Equal(t, ErrRequestTimeout, err)
+    var wg sync.WaitGroup
+
+    // Simulate a Ping request that is expected to time out.
+    wg.Add(1)
+
+    go func() {
+        defer wg.Done()
+
+        // Note that the timeout is set to 1 second.
+        success, err := coordinator.Ping(1 * time.Second)
+        require.Equal(t, ErrRequestTimeout, err)
+        require.False(t, success)
+    }()
+
+    // Get the request sent over the mock stream.
+    req1, err := getRequest(stream)
+    require.NoError(t, err)
+
+    // Verify that the request has the expected request ID and that it's a
+    // Ping request.
+    require.Equal(t, uint64(2), req1.GetRequestId())
+    require.True(t, req1.GetPing())
+
+    // Verify that the coordinator has correctly set up a single response
+    // channel for the Ping request with the specific request ID.
+    require.Equal(t, coordinator.responses.Len(), 1)
+    _, ok := coordinator.responses.Load(uint64(2))
+    require.True(t, ok)
+
+    // Now wait for the request to time out.
+    wg.Wait()
 
     // Verify that the responses map is empty after the timeout.
     require.Equal(t, coordinator.responses.Len(), 0)
 
     // Now let's simulate that the response is sent back after the request
     // has timed out.
-    req, err := getRequest(stream)
-    require.NoError(t, err)
-
-    require.Equal(t, uint64(2), req.GetRequestId())
-    require.True(t, req.GetPing())
-
     stream.sendResponse(&walletrpc.SignCoordinatorResponse{
         RefRequestId: 2,
         SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
             Pong: true,
         },
     })
@@ -718,7 +759,7 @@ func TestRemoteSignerDisconnects(t *testing.T) {
         defer wg.Done()
 
         success, err := coordinator.Ping(pingTimeout)
-        require.Equal(t, ErrNotConnected, err)
+        require.Equal(t, ErrConnectTimeout, err)
         require.False(t, success)
     }()
 
@@ -742,7 +783,7 @@ func TestRemoteSignerDisconnects(t *testing.T) {
     stream.Cancel()
 
-    // This should cause the Run function to return the error that the
-    // stream was canceled with.
+    // This should cause the Run function to return an error indicating
+    // that the stream was canceled.
     err = <-runErrChan
     require.Equal(t, ErrStreamCanceled, err)
 
@@ -752,11 +793,211 @@ func TestRemoteSignerDisconnects(t *testing.T) {
     // Verify that the coordinator signals that it's done receiving
     // responses after the stream is canceled, i.e. the StartReceiving
     // function is no longer running.
-    <-coordinator.doneReceiving
+    <-coordinator.disconnected
 
-    // Ensure that the Ping request goroutine returned before the timeout
-    // was reached, which indicates that the request was canceled because
-    // the remote signer disconnected.
+    // Ensure that the Ping request goroutine returned only once the ping
+    // timeout was reached, as the request now waits for the remote signer
+    // to reconnect after the disconnect, and times out when it doesn't.
+    require.Greater(t, time.Since(startTime), pingTimeout)
+    require.Less(t, time.Since(startTime), pingTimeout+100*time.Millisecond)
+
+    // Verify the responses map is empty after all responses are received
+    require.Equal(t, coordinator.responses.Len(), 0)
+}
+
+// TestRemoteSignerReconnectsDuringResponseWait verifies that the sign
+// coordinator correctly handles the scenario where the remote signer
+// disconnects while a request is being processed and then reconnects.
+// In this case, the sign coordinator should establish a new stream,
+// reprocess the request, and ultimately receive a response.
+func TestRemoteSignerReconnectsDuringResponseWait(t *testing.T) {
+    t.Parallel()
+
+    coordinator, stream, runErrChan := setupSignCoordinator(t)
+
+    pingTimeout := 3 * time.Second
+    startTime := time.Now()
+
+    var wg sync.WaitGroup
+
+    // Send a Ping request with a long timeout to ensure that the request
+    // will not time out before the remote signer disconnects.
+    wg.Add(1)
+
+    go func() {
+        defer wg.Done()
+
+        success, err := coordinator.Ping(pingTimeout)
+        require.NoError(t, err)
+        require.True(t, success)
+    }()
+
+    // Get the request sent over the mock stream.
+    req, err := getRequest(stream)
+    require.NoError(t, err)
+
+    // Verify that the request has the expected request ID and that it's a
+    // Ping request.
+    require.Equal(t, uint64(2), req.GetRequestId())
+    require.True(t, req.GetPing())
+
+    // Verify that the coordinator has correctly set up a single response
+    // channel for the Ping request with the specific request ID.
+    require.Equal(t, coordinator.responses.Len(), 1)
+    _, ok := coordinator.responses.Load(uint64(2))
+    require.True(t, ok)
+
+    // Now, let's simulate that the remote signer disconnects by canceling
+    // the stream, while the sign coordinator is still waiting for the Pong
+    // response for the request it sent.
+    stream.Cancel()
+
+    // This should cause the Run function to return an error indicating
+    // that the stream was canceled.
+    err = <-runErrChan
+    require.Equal(t, ErrStreamCanceled, err)
+
+    // Verify that the coordinator signals that it's done receiving
+    // responses after the stream is canceled, i.e. the StartReceiving
+    // function is no longer running.
+    <-coordinator.disconnected
+
+    // Now let's simulate that the remote signer reconnects with a new
+    // stream.
+    stream, runErrChan = setupNewStream(t, coordinator)
+
+    // This should cause the sign coordinator to resend the Ping request
+    // it still needs a response for over the new stream.
+    req, err = getRequest(stream)
+    require.NoError(t, err)
+
+    // Note that the request ID will be 3 for the resent request, as the
+    // coordinator assigns a new request ID whenever it needs to resend a
+    // request over a new stream.
+    require.Equal(t, uint64(3), req.GetRequestId())
+    require.True(t, req.GetPing())
+
+    // Verify that the coordinator has set up a response channel for the
+    // resent Ping request (the entry for the original request is still
+    // tracked as well).
+    require.Equal(t, coordinator.responses.Len(), 2)
+    _, ok = coordinator.responses.Load(uint64(3))
+    require.True(t, ok)
+
+    // Now let's send the Pong response for the resent Ping request.
+    stream.sendResponse(&walletrpc.SignCoordinatorResponse{
+        RefRequestId: 3,
+        SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
+            Pong: true,
+        },
+    })
+
+    // Ensure that the Ping request goroutine has finished.
+    wg.Wait()
+
+    // Ensure that the Ping request goroutine returned before the timeout
+    // was reached, which indicates that the request didn't time out as
+    // the remote signer reconnected in time and sent a response.
+    require.Less(t, time.Since(startTime), pingTimeout)
+
+    // Verify the responses map is empty after all responses are received
+    require.Equal(t, coordinator.responses.Len(), 0)
+}
+
+// TestRemoteSignerDisconnectsMidSend verifies that the sign coordinator
+// correctly handles the scenario in which the remote signer disconnects while
+// the sign coordinator is sending data over the stream (i.e., during the
+// execution of the `Send` function) and then reconnects. In such a case, the
+// sign coordinator should establish a new stream, reprocess the request, and
+// eventually receive a response.
+func TestRemoteSignerDisconnectsMidSend(t *testing.T) {
+    t.Parallel()
+
+    coordinator, stream, runErrChan := setupSignCoordinator(t)
+
+    pingTimeout := 3 * time.Second
+    startTime := time.Now()
+
+    var wg sync.WaitGroup
+
+    // Send a Ping request with a long timeout to ensure that the request
+    // will not time out before the remote signer disconnects.
+    wg.Add(1)
+
+    go func() {
+        defer wg.Done()
+
+        success, err := coordinator.Ping(pingTimeout)
+        require.NoError(t, err)
+        require.True(t, success)
+    }()
+
+    // Wait briefly to ensure that the Ping request starts getting
+    // processed before we simulate the remote signer disconnecting.
+    <-time.After(10 * time.Millisecond)
+
+    // We simulate the remote signer disconnecting by canceling the
+    // stream.
+    stream.Cancel()
+
+    // This should cause the Run function to return an error indicating
+    // that the stream was canceled.
+    err := <-runErrChan
+    require.Equal(t, ErrStreamCanceled, err)
+
+    // Verify that the coordinator signals that it's done receiving
+    // responses after the stream is canceled, i.e. the StartReceiving
+    // function is no longer running.
+    <-coordinator.disconnected
+
+    // Since the sign coordinator is still processing the request, and we
+    // never extracted the request sent over the stream, the sign
+    // coordinator is stuck in the stream.Send call. We now simulate that
+    // this call errors with codes.Unavailable, which is the error it
+    // would return if the signer disconnected during the send operation
+    // in a real scenario.
+    stream.sendErrorChan <- status.Errorf(
+        codes.Unavailable, "simulated unavailable error",
+    )
+
+    // Verify that the coordinator has correctly set up a single response
+    // channel for the Ping request with the specific request ID.
+    require.Equal(t, 1, coordinator.responses.Len())
+
+    // Now let's simulate that the remote signer reconnects with a new
+    // stream.
+    stream, runErrChan = setupNewStream(t, coordinator)
+
+    // This should cause the sign coordinator to resend the Ping request
+    // it still needs a response for over the new stream.
+    req, err := getRequest(stream)
+    require.NoError(t, err)
+
+    // Note that the request ID will be 3 for the resent request, as the
+    // coordinator assigns a new request ID whenever it needs to resend a
+    // request over a new stream.
+    require.Equal(t, uint64(3), req.GetRequestId())
+    require.True(t, req.GetPing())
+
+    // Verify that the coordinator has set up a response channel for the
+    // resent Ping request (the entry for the original request is still
+    // tracked as well).
+    require.Equal(t, coordinator.responses.Len(), 2)
+    _, ok := coordinator.responses.Load(uint64(3))
+    require.True(t, ok)
+
+    // Now let's send the Pong response for the resent Ping request.
+    stream.sendResponse(&walletrpc.SignCoordinatorResponse{
+        RefRequestId: 3,
+        SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
+            Pong: true,
+        },
+    })
+
+    // Ensure that the Ping request goroutine has finished.
+    wg.Wait()
+
+    // Ensure that the Ping request goroutine returned before the timeout
+    // was reached, which indicates that the request didn't time out as
+    // the remote signer reconnected in time and sent a response.
     require.Less(t, time.Since(startTime), pingTimeout)
 
     // Verify the responses map is empty after all responses are received
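A side note on the retry behaviour introduced above: when a request is resent
after the signer reconnects, the retry only gets whatever remains of the
caller's original timeout budget, and the retry is abandoned once that budget
is spent. Below is a minimal, standalone Go sketch of that bookkeeping; the
names (remainingTimeout, errRequestTimeout) are hypothetical and are not the
helpers used in the patch itself.

package main

import (
    "errors"
    "fmt"
    "time"
)

// errRequestTimeout stands in for the coordinator's ErrRequestTimeout in
// this sketch.
var errRequestTimeout = errors.New("request timed out")

// remainingTimeout mimics the bookkeeping the retry path performs before a
// request is resent: compute what is left of the caller's original timeout
// budget, and abandon the retry if that budget is already spent. A zero
// timeout is treated as "no timeout".
func remainingTimeout(timeout time.Duration,
    startTime time.Time) (time.Duration, error) {

    if timeout == 0 {
        return 0, nil
    }

    elapsed := time.Since(startTime)
    if elapsed > timeout {
        return 0, errRequestTimeout
    }

    return timeout - elapsed, nil
}

func main() {
    // Pretend the original request was sent two seconds ago.
    start := time.Now().Add(-2 * time.Second)

    left, err := remainingTimeout(3*time.Second, start)
    fmt.Println(left, err) // roughly 1s left, <nil>

    _, err = remainingTimeout(1*time.Second, start)
    fmt.Println(err) // request timed out
}

In the patch this logic lives inline in reprocessOnDisconnect, so the
recursive processRequest call never waits past the caller's original deadline.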