rpcwallet: allow remote signer to reconnect

Allow the remote signer to reconnect to the wallet after disconnecting,
as long as the remote signer reconnects within the timeout limit.

This is not a complete solution to the problem of allowing the watch-only
node to stay online while the remote signer is disconnected, but it is more
fault-tolerant than the current implementation, as it allows the remote
signer to be temporarily disconnected.
This commit is contained in:
Viktor Tigerström
2025-03-21 19:14:33 +01:00
parent 8c82e5f7bf
commit a68d8d4823
2 changed files with 351 additions and 51 deletions

View File

@@ -8,6 +8,9 @@ import (
"sync/atomic"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/lightningnetwork/lnd/lnrpc/signrpc"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lnutils"
@@ -77,16 +80,20 @@ type SignCoordinator struct {
// signer has errored, and we can no longer process responses.
receiveErrChan chan error
// doneReceiving is closed when either party terminates and signals to
// disconnected is closed when either party terminates and signals to
// any pending requests that we'll no longer process the response for
// that request.
doneReceiving chan struct{}
disconnected chan struct{}
// quit is closed when lnd is shutting down.
quit chan struct{}
// clientConnected is sent over when the remote signer connects.
clientConnected chan struct{}
// clientReady is closed when the remote signer is connected and ready
// to accept requests (after the initial handshake).
clientReady chan struct{}
// clientConnected is true if a remote signer is currently connected.
clientConnected bool
// requestTimeout is the maximum time we will wait for a response from
// the remote signer.
@@ -114,11 +121,14 @@ func NewSignCoordinator(requestTimeout time.Duration,
s := &SignCoordinator{
responses: respsMap,
receiveErrChan: make(chan error, 1),
doneReceiving: make(chan struct{}),
clientConnected: make(chan struct{}),
clientReady: make(chan struct{}),
clientConnected: false,
quit: make(chan struct{}),
requestTimeout: requestTimeout,
connectionTimeout: connectionTimeout,
// Note that the disconnected channel is not initialized here,
// as no code listens to it until the Run method has been called
// and set the field.
}
// We initialize the atomic nextRequestID to the handshakeRequestID, as
@@ -139,25 +149,33 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
s.mu.Unlock()
return ErrShuttingDown
case <-s.doneReceiving:
s.mu.Unlock()
return ErrNotConnected
default:
}
if s.clientConnected {
// If we already have a stream, we error out as we can only have
// one connection at a time.
return ErrMultipleConnections
}
s.wg.Add(1)
defer s.wg.Done()
// If we already have a stream, we error out as we can only have one
// connection throughout the lifetime of the SignCoordinator.
if s.stream != nil {
s.mu.Unlock()
return ErrMultipleConnections
}
s.clientConnected = true
defer func() {
s.mu.Lock()
defer s.mu.Unlock()
// When `Run` returns, we set the clientConnected field to false
// to allow a new remote signer connection to be set up.
s.clientConnected = false
}()
s.stream = stream
s.disconnected = make(chan struct{})
defer close(s.disconnected)
s.mu.Unlock()
// The handshake must be completed before we can start sending requests
@@ -167,8 +185,18 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
return err
}
log.Infof("Remote signer connected")
close(s.clientConnected)
log.Infof("Remote signer connected and ready")
close(s.clientReady)
defer func() {
s.mu.Lock()
defer s.mu.Unlock()
// We create a new clientReady channel, once this function
// has exited, to ensure that a new remote signer connection can
// be set up.
s.clientReady = make(chan struct{})
}()
// Now let's start the main receiving loop, which will receive all
// responses to our requests from the remote signer!
@@ -186,9 +214,6 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
case <-s.quit:
return ErrShuttingDown
case <-s.doneReceiving:
return ErrNotConnected
}
}
@@ -371,10 +396,6 @@ func (s *SignCoordinator) handshake(stream StreamServer) error {
func (s *SignCoordinator) StartReceiving() {
defer s.wg.Done()
// Signals to any ongoing requests that the remote signer is no longer
// connected.
defer close(s.doneReceiving)
for {
resp, err := s.stream.Recv()
if err != nil {
@@ -442,8 +463,16 @@ func (s *SignCoordinator) WaitUntilConnected() error {
func (s *SignCoordinator) waitUntilConnectedWithTimeout(
timeout time.Duration) error {
// As the Run method will replace the clientReady channel once it
// returns, we need to copy the reference to the current clientReady
// channel while holding the mutex, to ensure that we're waiting on the
// correct channel and to avoid a data race.
s.mu.Lock()
currentClientReady := s.clientReady
s.mu.Unlock()
select {
case <-s.clientConnected:
case <-currentClientReady:
return nil
case <-s.quit:
@@ -451,9 +480,6 @@ func (s *SignCoordinator) waitUntilConnectedWithTimeout(
case <-time.After(timeout):
return ErrConnectTimeout
case <-s.doneReceiving:
return ErrNotConnected
}
}
@@ -537,7 +563,7 @@ func (s *SignCoordinator) getResponse(requestID uint64,
return resp, nil
case <-s.doneReceiving:
case <-s.disconnected:
log.Debugf("Stopped waiting for remote signer response for "+
"request ID %d as the stream has been closed",
requestID)
@@ -854,8 +880,36 @@ func processRequest[R comparable](s *SignCoordinator, timeout time.Duration,
log.Tracef("Request content: %v", formatSignCoordinatorMsg(&req))
// reprocessOnDisconnect is a helper function that is used to retry the
// request if the remote signer disconnects. The retried request waits
// for the remote signer to reconnect, and is then resent with whatever
// remains of the original timeout (if one was set).
reprocessOnDisconnect := func() (R, error) {
var newTimeout time.Duration = noTimeout
if timeout != 0 {
newTimeout = timeout - time.Since(startTime)
if time.Since(startTime) > timeout {
return zero, ErrRequestTimeout
}
}
return processRequest[R](
s, newTimeout, generateRequest, extractResponse,
)
}
err = s.stream.Send(&req)
if err != nil {
st, isStatusError := status.FromError(err)
if isStatusError && st.Code() == codes.Unavailable {
// If the stream was closed due to the remote signer
// disconnecting, we will retry to process the request
// if the remote signer reconnects.
return reprocessOnDisconnect()
}
return zero, err
}
@@ -878,7 +932,12 @@ func processRequest[R comparable](s *SignCoordinator, timeout time.Duration,
resp, err = s.getResponse(reqID, s.requestTimeout)
}
if err != nil {
if errors.Is(err, ErrNotConnected) {
// If the remote signer disconnected while we were waiting for
// the response, we will retry to process the request if the
// remote signer reconnects.
return reprocessOnDisconnect()
} else if err != nil {
return zero, err
}

View File

@@ -6,6 +6,9 @@ import (
"testing"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/stretchr/testify/require"
@@ -19,6 +22,10 @@ type mockSCStream struct {
// sign coordinator to the remote signer.
sendChan chan *walletrpc.SignCoordinatorRequest
// sendErrorChan is used to simulate errors returned by the stream's Send
// function when the sign coordinator sends a request to the remote signer.
sendErrorChan chan error
// recvChan is used to simulate responses sent over the stream from the
// remote signer to the sign coordinator.
recvChan chan *walletrpc.SignCoordinatorResponse
@@ -32,10 +39,11 @@ type mockSCStream struct {
// newMockSCStream creates a new mock stream.
func newMockSCStream() *mockSCStream {
return &mockSCStream{
sendChan: make(chan *walletrpc.SignCoordinatorRequest, 1),
recvChan: make(chan *walletrpc.SignCoordinatorResponse, 1),
cancelChan: make(chan struct{}),
ctx: context.Background(),
sendChan: make(chan *walletrpc.SignCoordinatorRequest),
sendErrorChan: make(chan error),
recvChan: make(chan *walletrpc.SignCoordinatorResponse, 1),
cancelChan: make(chan struct{}),
ctx: context.Background(),
}
}
@@ -46,8 +54,8 @@ func (ms *mockSCStream) Send(req *walletrpc.SignCoordinatorRequest) error {
case ms.sendChan <- req:
return nil
case <-ms.cancelChan:
return ErrStreamCanceled
case err := <-ms.sendErrorChan:
return err
}
}
@@ -94,8 +102,19 @@ func (ms *mockSCStream) sendResponse(resp *walletrpc.SignCoordinatorResponse) {
func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
chan error) {
stream := newMockSCStream()
coordinator := NewSignCoordinator(2*time.Second, 3*time.Second)
stream, errChan := setupNewStream(t, coordinator)
return coordinator, stream, errChan
}
// setupNewStream sets up a new mock stream to simulate a communication with a
// remote signer. It also simulates the handshake between the passed sign
// coordinator and the remote signer.
func setupNewStream(t *testing.T,
coordinator *SignCoordinator) (*mockSCStream, chan error) {
stream := newMockSCStream()
errChan := make(chan error)
go func() {
@@ -136,7 +155,7 @@ func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
)
}
return coordinator, stream, errChan
return stream, errChan
}
// getRequest is a helper function to get a request that has been sent from
@@ -369,21 +388,43 @@ func TestPingTimeout(t *testing.T) {
coordinator, stream, _ := setupSignCoordinator(t)
// Simulate a Ping request that times out.
_, err := coordinator.Ping(1 * time.Second)
require.Equal(t, ErrRequestTimeout, err)
var wg sync.WaitGroup
// Simulate a Ping request that is expected to time out.
wg.Add(1)
go func() {
defer wg.Done()
// Note that the timeout is set to 1 second.
success, err := coordinator.Ping(1 * time.Second)
require.Equal(t, ErrRequestTimeout, err)
require.False(t, success)
}()
// Get the request sent over the mock stream.
req1, err := getRequest(stream)
require.NoError(t, err)
// Verify that the request has the expected request ID and that it's a
// Ping request.
require.Equal(t, uint64(2), req1.GetRequestId())
require.True(t, req1.GetPing())
// Verify that the coordinator has correctly set up a single response
// channel for the Ping request with the specific request ID.
require.Equal(t, coordinator.responses.Len(), 1)
_, ok := coordinator.responses.Load(uint64(2))
require.True(t, ok)
// Now wait for the request to time out.
wg.Wait()
// Verify that the responses map is empty after the timeout.
require.Equal(t, coordinator.responses.Len(), 0)
// Now let's simulate that the response is sent back after the request
// has timed out.
req, err := getRequest(stream)
require.NoError(t, err)
require.Equal(t, uint64(2), req.GetRequestId())
require.True(t, req.GetPing())
stream.sendResponse(&walletrpc.SignCoordinatorResponse{
RefRequestId: 2,
SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
@@ -718,7 +759,7 @@ func TestRemoteSignerDisconnects(t *testing.T) {
defer wg.Done()
success, err := coordinator.Ping(pingTimeout)
require.Equal(t, ErrNotConnected, err)
require.Equal(t, ErrConnectTimeout, err)
require.False(t, success)
}()
@@ -742,7 +783,7 @@ func TestRemoteSignerDisconnects(t *testing.T) {
stream.Cancel()
// This should cause the Run function to return the error that the
// stream was canceled with.
// stream was canceled.
err = <-runErrChan
require.Equal(t, ErrStreamCanceled, err)
@@ -752,11 +793,211 @@ func TestRemoteSignerDisconnects(t *testing.T) {
// Verify that the coordinator signals that it's done receiving
// responses after the stream is canceled, i.e. the StartReceiving
// function is no longer running.
<-coordinator.doneReceiving
<-coordinator.disconnected
// Ensure that the Ping request goroutine returned before the timeout
// was reached, which indicates that the request was canceled because
// the remote signer disconnected.
require.Greater(t, time.Since(startTime), pingTimeout)
require.Less(t, time.Since(startTime), pingTimeout+100*time.Millisecond)
// Verify the responses map is empty after all responses are received
require.Equal(t, coordinator.responses.Len(), 0)
}
// TestRemoteSignerReconnectsDuringResponseWait verifies that the sign
// coordinator correctly handles the scenario where the remote signer
// disconnects while a request is being processed and then reconnects. In this
// case, the sign coordinator should establish a new stream, reprocess the
// request, and ultimately receive a response.
func TestRemoteSignerReconnectsDuringResponseWait(t *testing.T) {
	t.Parallel()

	coordinator, stream, runErrChan := setupSignCoordinator(t)

	pingTimeout := 3 * time.Second
	startTime := time.Now()

	var wg sync.WaitGroup

	// Send a Ping request with a long timeout to ensure that the request
	// will not time out before the remote signer disconnects.
	wg.Add(1)
	go func() {
		defer wg.Done()

		// We report failures with t.Errorf rather than the require
		// helpers here, as require calls t.FailNow, which must only be
		// called from the goroutine running the test function.
		success, err := coordinator.Ping(pingTimeout)
		if err != nil {
			t.Errorf("unexpected Ping error: %v", err)
			return
		}
		if !success {
			t.Errorf("expected Ping to succeed")
		}
	}()

	// Get the request sent over the mock stream.
	req, err := getRequest(stream)
	require.NoError(t, err)

	// Verify that the request has the expected request ID and that it's a
	// Ping request.
	require.Equal(t, uint64(2), req.GetRequestId())
	require.True(t, req.GetPing())

	// Verify that the coordinator has correctly set up a single response
	// channel for the Ping request with the specific request ID.
	require.Equal(t, 1, coordinator.responses.Len())
	_, ok := coordinator.responses.Load(uint64(2))
	require.True(t, ok)

	// Now, let's simulate that the remote signer disconnects by canceling
	// the stream, while the sign coordinator is still waiting for the Pong
	// response for the request it sent.
	stream.Cancel()

	// This should cause the Run function to return the error that the
	// stream was canceled.
	err = <-runErrChan
	require.Equal(t, ErrStreamCanceled, err)

	// Verify that the coordinator has closed the disconnected channel,
	// which signals to any pending requests that responses will no longer
	// be processed over the canceled stream.
	<-coordinator.disconnected

	// Now let's simulate that the remote signer reconnects with a new
	// stream.
	stream, runErrChan = setupNewStream(t, coordinator)

	// This should lead to the sign coordinator resending the Ping request
	// it needs a response for over the new stream.
	req, err = getRequest(stream)
	require.NoError(t, err)

	// Note that the request ID will be 3 for the resent request, as the
	// coordinator will no longer wait for the response for the request with
	// request ID 2.
	require.Equal(t, uint64(3), req.GetRequestId())
	require.True(t, req.GetPing())

	// Verify that the coordinator has set up a response channel for the
	// resent Ping request. Note that the responses map is expected to hold
	// two entries at this point; presumably the entry for the abandoned
	// request ID 2 is only removed once the original processRequest call
	// returns (TODO: confirm against processRequest).
	require.Equal(t, 2, coordinator.responses.Len())
	_, ok = coordinator.responses.Load(uint64(3))
	require.True(t, ok)

	// Now let's send the Pong response for the resent Ping request.
	stream.sendResponse(&walletrpc.SignCoordinatorResponse{
		RefRequestId: 3,
		SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
			Pong: true,
		},
	})

	// Ensure that the Ping request goroutine has finished.
	wg.Wait()

	// Ensure that the Ping request goroutine returned before the timeout
	// was reached, which indicates that the request didn't time out as
	// the remote signer reconnected in time and sent a response.
	require.Less(t, time.Since(startTime), pingTimeout)

	// Verify the responses map is empty after all responses are received.
	require.Equal(t, 0, coordinator.responses.Len())
}
// TestRemoteSignerDisconnectsMidSend verifies that the sign coordinator
// correctly handles the scenario in which the remote signer disconnects while
// the sign coordinator is sending data over the stream (i.e., during the
// execution of the `Send` function) and then reconnects. In such a case, the
// sign coordinator should establish a new stream, reprocess the request, and
// eventually receive a response.
func TestRemoteSignerDisconnectsMidSend(t *testing.T) {
t.Parallel()
coordinator, stream, runErrChan := setupSignCoordinator(t)
pingTimeout := 3 * time.Second
startTime := time.Now()
var wg sync.WaitGroup
// Send a Ping request with a long timeout to ensure that the request
// will not time out before the remote signer disconnects.
wg.Add(1)
go func() {
defer wg.Done()
success, err := coordinator.Ping(pingTimeout)
require.NoError(t, err)
require.True(t, success)
}()
// Just wait slightly, to ensure that the Ping request starts getting
// processed before we simulate the remote signer disconnecting.
<-time.After(10 * time.Millisecond)
// We simulate the remote signer disconnecting by canceling the
// stream.
stream.Cancel()
// This should cause the Run function to return the error that the
// stream was canceled with.
err := <-runErrChan
require.Equal(t, ErrStreamCanceled, err)
// Verify that the coordinator signals that it's done receiving
// responses after the stream is canceled, i.e. the StartReceiving
// function is no longer running.
<-coordinator.disconnected
// Now since the sign coordinator is still processing the request, and
// we never extracted the request sent over the stream, the sign
// coordinator is stuck at the stream.Send function. We simulate that this
// function now errors with the codes.Unavailable error, which is what
// the function would error with if the signer was disconnected during
// the send operation in a real scenario.
stream.sendErrorChan <- status.Errorf(
codes.Unavailable, "simulated unavailable error",
)
// Verify that the coordinator has correctly set up a single response
// channel for the Ping request with the specific request ID.
require.Equal(t, 1, coordinator.responses.Len())
// Now let's simulate that the remote signer reconnects with a new
// stream.
stream, runErrChan = setupNewStream(t, coordinator)
// This should lead to the sign coordinator resending the Ping
// request it needs a response for over the new stream.
req, err := getRequest(stream)
require.NoError(t, err)
// Note that the request ID will be 3 for the resent request, as the
// coordinator will no longer wait for the response for the request with
// request ID 2.
require.Equal(t, uint64(3), req.GetRequestId())
require.True(t, req.GetPing())
// Verify that the coordinator has set up a response channel for the
// resent Ping request; the responses map holds two entries at this point.
require.Equal(t, coordinator.responses.Len(), 2)
_, ok := coordinator.responses.Load(uint64(3))
require.True(t, ok)
// Now let's send the Pong response for the resent Ping request.
stream.sendResponse(&walletrpc.SignCoordinatorResponse{
RefRequestId: 3,
SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
Pong: true,
},
})
// Ensure that the Ping request goroutine has finished.
wg.Wait()
// Ensure that the Ping request goroutine returned before the timeout
// was reached, which indicates that the request didn't time out as
// the remote signer reconnected in time and sent a response.
require.Less(t, time.Since(startTime), pingTimeout)
// Verify the responses map is empty after all responses are received