mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-10 20:13:29 +02:00
rpcwallet: allow remote signer to reconnect
Allow the remote signer to reconnect to the wallet after disconnecting, as long as it reconnects within the timeout limit. This is not a complete solution for keeping the watch-only node online while the remote signer is disconnected, but it is more fault-tolerant than the current implementation, as it allows the remote signer to be temporarily disconnected.
This commit is contained in:
@@ -8,6 +8,9 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/lnrpc/signrpc"
|
"github.com/lightningnetwork/lnd/lnrpc/signrpc"
|
||||||
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
|
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
|
||||||
"github.com/lightningnetwork/lnd/lnutils"
|
"github.com/lightningnetwork/lnd/lnutils"
|
||||||
@@ -77,16 +80,20 @@ type SignCoordinator struct {
|
|||||||
// signer has errored, and we can no longer process responses.
|
// signer has errored, and we can no longer process responses.
|
||||||
receiveErrChan chan error
|
receiveErrChan chan error
|
||||||
|
|
||||||
// doneReceiving is closed when either party terminates and signals to
|
// disconnected is closed when either party terminates and signals to
|
||||||
// any pending requests that we'll no longer process the response for
|
// any pending requests that we'll no longer process the response for
|
||||||
// that request.
|
// that request.
|
||||||
doneReceiving chan struct{}
|
disconnected chan struct{}
|
||||||
|
|
||||||
// quit is closed when lnd is shutting down.
|
// quit is closed when lnd is shutting down.
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
|
|
||||||
// clientConnected is sent over when the remote signer connects.
|
// clientReady is closed when the remote signer is
|
||||||
clientConnected chan struct{}
|
// connected and ready to accept requests (after the initial handshake).
|
||||||
|
clientReady chan struct{}
|
||||||
|
|
||||||
|
// clientConnected is true if a remote signer is currently connected.
|
||||||
|
clientConnected bool
|
||||||
|
|
||||||
// requestTimeout is the maximum time we will wait for a response from
|
// requestTimeout is the maximum time we will wait for a response from
|
||||||
// the remote signer.
|
// the remote signer.
|
||||||
@@ -114,11 +121,14 @@ func NewSignCoordinator(requestTimeout time.Duration,
|
|||||||
s := &SignCoordinator{
|
s := &SignCoordinator{
|
||||||
responses: respsMap,
|
responses: respsMap,
|
||||||
receiveErrChan: make(chan error, 1),
|
receiveErrChan: make(chan error, 1),
|
||||||
doneReceiving: make(chan struct{}),
|
clientReady: make(chan struct{}),
|
||||||
clientConnected: make(chan struct{}),
|
clientConnected: false,
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
requestTimeout: requestTimeout,
|
requestTimeout: requestTimeout,
|
||||||
connectionTimeout: connectionTimeout,
|
connectionTimeout: connectionTimeout,
|
||||||
|
// Note that the disconnected channel is not initialized here,
|
||||||
|
// as no code listens to it until the Run method has been called
|
||||||
|
// and set the field.
|
||||||
}
|
}
|
||||||
|
|
||||||
// We initialize the atomic nextRequestID to the handshakeRequestID, as
|
// We initialize the atomic nextRequestID to the handshakeRequestID, as
|
||||||
@@ -139,25 +149,33 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return ErrShuttingDown
|
return ErrShuttingDown
|
||||||
|
|
||||||
case <-s.doneReceiving:
|
|
||||||
s.mu.Unlock()
|
|
||||||
return ErrNotConnected
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.clientConnected {
|
||||||
|
// If we already have a stream, we error out as we can only have
|
||||||
|
// one connection at a time.
|
||||||
|
return ErrMultipleConnections
|
||||||
|
}
|
||||||
|
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
|
|
||||||
// If we already have a stream, we error out as we can only have one
|
s.clientConnected = true
|
||||||
// connection throughout the lifetime of the SignCoordinator.
|
defer func() {
|
||||||
if s.stream != nil {
|
s.mu.Lock()
|
||||||
s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
return ErrMultipleConnections
|
|
||||||
}
|
// When `Run` returns, we set the clientConnected field to false
|
||||||
|
// to allow a new remote signer connection to be set up.
|
||||||
|
s.clientConnected = false
|
||||||
|
}()
|
||||||
|
|
||||||
s.stream = stream
|
s.stream = stream
|
||||||
|
|
||||||
|
s.disconnected = make(chan struct{})
|
||||||
|
defer close(s.disconnected)
|
||||||
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// The handshake must be completed before we can start sending requests
|
// The handshake must be completed before we can start sending requests
|
||||||
@@ -167,8 +185,18 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Remote signer connected")
|
log.Infof("Remote signer connected and ready")
|
||||||
close(s.clientConnected)
|
|
||||||
|
close(s.clientReady)
|
||||||
|
defer func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// We create a new clientReady channel, once this function
|
||||||
|
// has exited, to ensure that a new remote signer connection can
|
||||||
|
// be set up.
|
||||||
|
s.clientReady = make(chan struct{})
|
||||||
|
}()
|
||||||
|
|
||||||
// Now let's start the main receiving loop, which will receive all
|
// Now let's start the main receiving loop, which will receive all
|
||||||
// responses to our requests from the remote signer!
|
// responses to our requests from the remote signer!
|
||||||
@@ -186,9 +214,6 @@ func (s *SignCoordinator) Run(stream StreamServer) error {
|
|||||||
|
|
||||||
case <-s.quit:
|
case <-s.quit:
|
||||||
return ErrShuttingDown
|
return ErrShuttingDown
|
||||||
|
|
||||||
case <-s.doneReceiving:
|
|
||||||
return ErrNotConnected
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -371,10 +396,6 @@ func (s *SignCoordinator) handshake(stream StreamServer) error {
|
|||||||
func (s *SignCoordinator) StartReceiving() {
|
func (s *SignCoordinator) StartReceiving() {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
|
|
||||||
// Signals to any ongoing requests that the remote signer is no longer
|
|
||||||
// connected.
|
|
||||||
defer close(s.doneReceiving)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
resp, err := s.stream.Recv()
|
resp, err := s.stream.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -442,8 +463,16 @@ func (s *SignCoordinator) WaitUntilConnected() error {
|
|||||||
func (s *SignCoordinator) waitUntilConnectedWithTimeout(
|
func (s *SignCoordinator) waitUntilConnectedWithTimeout(
|
||||||
timeout time.Duration) error {
|
timeout time.Duration) error {
|
||||||
|
|
||||||
|
// As the Run method will redefine the clientReady channel once it
|
||||||
|
// returns, we need to copy the pointer to the current clientReady channel
|
||||||
|
// to ensure that we're waiting for the correct channel, and to avoid
|
||||||
|
// a data race.
|
||||||
|
s.mu.Lock()
|
||||||
|
currentClientReady := s.clientReady
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-s.clientConnected:
|
case <-currentClientReady:
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case <-s.quit:
|
case <-s.quit:
|
||||||
@@ -451,9 +480,6 @@ func (s *SignCoordinator) waitUntilConnectedWithTimeout(
|
|||||||
|
|
||||||
case <-time.After(timeout):
|
case <-time.After(timeout):
|
||||||
return ErrConnectTimeout
|
return ErrConnectTimeout
|
||||||
|
|
||||||
case <-s.doneReceiving:
|
|
||||||
return ErrNotConnected
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -537,7 +563,7 @@ func (s *SignCoordinator) getResponse(requestID uint64,
|
|||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
|
|
||||||
case <-s.doneReceiving:
|
case <-s.disconnected:
|
||||||
log.Debugf("Stopped waiting for remote signer response for "+
|
log.Debugf("Stopped waiting for remote signer response for "+
|
||||||
"request ID %d as the stream has been closed",
|
"request ID %d as the stream has been closed",
|
||||||
requestID)
|
requestID)
|
||||||
@@ -854,8 +880,36 @@ func processRequest[R comparable](s *SignCoordinator, timeout time.Duration,
|
|||||||
|
|
||||||
log.Tracef("Request content: %v", formatSignCoordinatorMsg(&req))
|
log.Tracef("Request content: %v", formatSignCoordinatorMsg(&req))
|
||||||
|
|
||||||
|
// reprocessOnDisconnect is a helper function that will be used to
|
||||||
|
// resend the request if the remote signer disconnects, through which
|
||||||
|
// we will wait for it to reconnect within the configured timeout, and
|
||||||
|
// then resend the request.
|
||||||
|
reprocessOnDisconnect := func() (R, error) {
|
||||||
|
var newTimeout time.Duration = noTimeout
|
||||||
|
|
||||||
|
if timeout != 0 {
|
||||||
|
newTimeout = timeout - time.Since(startTime)
|
||||||
|
|
||||||
|
if time.Since(startTime) > timeout {
|
||||||
|
return zero, ErrRequestTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return processRequest[R](
|
||||||
|
s, newTimeout, generateRequest, extractResponse,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
err = s.stream.Send(&req)
|
err = s.stream.Send(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
st, isStatusError := status.FromError(err)
|
||||||
|
if isStatusError && st.Code() == codes.Unavailable {
|
||||||
|
// If the stream was closed due to the remote signer
|
||||||
|
// disconnecting, we will retry to process the request
|
||||||
|
// if the remote signer reconnects.
|
||||||
|
return reprocessOnDisconnect()
|
||||||
|
}
|
||||||
|
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -878,7 +932,12 @@ func processRequest[R comparable](s *SignCoordinator, timeout time.Duration,
|
|||||||
resp, err = s.getResponse(reqID, s.requestTimeout)
|
resp, err = s.getResponse(reqID, s.requestTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if errors.Is(err, ErrNotConnected) {
|
||||||
|
// If the remote signer disconnected while we were waiting for
|
||||||
|
// the response, we will retry to process the request if the
|
||||||
|
// remote signer reconnects.
|
||||||
|
return reprocessOnDisconnect()
|
||||||
|
} else if err != nil {
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -6,6 +6,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
|
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
|
||||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -19,6 +22,10 @@ type mockSCStream struct {
|
|||||||
// sign coordinator to the remote signer.
|
// sign coordinator to the remote signer.
|
||||||
sendChan chan *walletrpc.SignCoordinatorRequest
|
sendChan chan *walletrpc.SignCoordinatorRequest
|
||||||
|
|
||||||
|
// sendErrorChan simulates Send errors on the stream from the
|
||||||
|
// sign coordinator to the remote signer.
|
||||||
|
sendErrorChan chan error
|
||||||
|
|
||||||
// recvChan is used to simulate responses sent over the stream from the
|
// recvChan is used to simulate responses sent over the stream from the
|
||||||
// remote signer to the sign coordinator.
|
// remote signer to the sign coordinator.
|
||||||
recvChan chan *walletrpc.SignCoordinatorResponse
|
recvChan chan *walletrpc.SignCoordinatorResponse
|
||||||
@@ -32,7 +39,8 @@ type mockSCStream struct {
|
|||||||
// newMockSCStream creates a new mock stream.
|
// newMockSCStream creates a new mock stream.
|
||||||
func newMockSCStream() *mockSCStream {
|
func newMockSCStream() *mockSCStream {
|
||||||
return &mockSCStream{
|
return &mockSCStream{
|
||||||
sendChan: make(chan *walletrpc.SignCoordinatorRequest, 1),
|
sendChan: make(chan *walletrpc.SignCoordinatorRequest),
|
||||||
|
sendErrorChan: make(chan error),
|
||||||
recvChan: make(chan *walletrpc.SignCoordinatorResponse, 1),
|
recvChan: make(chan *walletrpc.SignCoordinatorResponse, 1),
|
||||||
cancelChan: make(chan struct{}),
|
cancelChan: make(chan struct{}),
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
@@ -46,8 +54,8 @@ func (ms *mockSCStream) Send(req *walletrpc.SignCoordinatorRequest) error {
|
|||||||
case ms.sendChan <- req:
|
case ms.sendChan <- req:
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case <-ms.cancelChan:
|
case err := <-ms.sendErrorChan:
|
||||||
return ErrStreamCanceled
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,8 +102,19 @@ func (ms *mockSCStream) sendResponse(resp *walletrpc.SignCoordinatorResponse) {
|
|||||||
func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
|
func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
|
||||||
chan error) {
|
chan error) {
|
||||||
|
|
||||||
stream := newMockSCStream()
|
|
||||||
coordinator := NewSignCoordinator(2*time.Second, 3*time.Second)
|
coordinator := NewSignCoordinator(2*time.Second, 3*time.Second)
|
||||||
|
stream, errChan := setupNewStream(t, coordinator)
|
||||||
|
|
||||||
|
return coordinator, stream, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupNewStream sets up a new mock stream to simulate a communication with a
|
||||||
|
// remote signer. It also simulates the handshake between the passed sign
|
||||||
|
// coordinator and the remote signer.
|
||||||
|
func setupNewStream(t *testing.T,
|
||||||
|
coordinator *SignCoordinator) (*mockSCStream, chan error) {
|
||||||
|
|
||||||
|
stream := newMockSCStream()
|
||||||
|
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -136,7 +155,7 @@ func setupSignCoordinator(t *testing.T) (*SignCoordinator, *mockSCStream,
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return coordinator, stream, errChan
|
return stream, errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRequest is a helper function to get a request that has been sent from
|
// getRequest is a helper function to get a request that has been sent from
|
||||||
@@ -369,21 +388,43 @@ func TestPingTimeout(t *testing.T) {
|
|||||||
|
|
||||||
coordinator, stream, _ := setupSignCoordinator(t)
|
coordinator, stream, _ := setupSignCoordinator(t)
|
||||||
|
|
||||||
// Simulate a Ping request that times out.
|
var wg sync.WaitGroup
|
||||||
_, err := coordinator.Ping(1 * time.Second)
|
|
||||||
|
// Simulate a Ping request that is expected to time out.
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Note that the timeout is set to 1 second.
|
||||||
|
success, err := coordinator.Ping(1 * time.Second)
|
||||||
require.Equal(t, ErrRequestTimeout, err)
|
require.Equal(t, ErrRequestTimeout, err)
|
||||||
|
require.False(t, success)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Get the request sent over the mock stream.
|
||||||
|
req1, err := getRequest(stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that the request has the expected request ID and that it's a
|
||||||
|
// Ping request.
|
||||||
|
require.Equal(t, uint64(2), req1.GetRequestId())
|
||||||
|
require.True(t, req1.GetPing())
|
||||||
|
|
||||||
|
// Verify that the coordinator has correctly set up a single response
|
||||||
|
// channel for the Ping request with the specific request ID.
|
||||||
|
require.Equal(t, coordinator.responses.Len(), 1)
|
||||||
|
_, ok := coordinator.responses.Load(uint64(2))
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// Now wait for the request to time out.
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
// Verify that the responses map is empty after the timeout.
|
// Verify that the responses map is empty after the timeout.
|
||||||
require.Equal(t, coordinator.responses.Len(), 0)
|
require.Equal(t, coordinator.responses.Len(), 0)
|
||||||
|
|
||||||
// Now let's simulate that the response is sent back after the request
|
// Now let's simulate that the response is sent back after the request
|
||||||
// has timed out.
|
// has timed out.
|
||||||
req, err := getRequest(stream)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, uint64(2), req.GetRequestId())
|
|
||||||
require.True(t, req.GetPing())
|
|
||||||
|
|
||||||
stream.sendResponse(&walletrpc.SignCoordinatorResponse{
|
stream.sendResponse(&walletrpc.SignCoordinatorResponse{
|
||||||
RefRequestId: 2,
|
RefRequestId: 2,
|
||||||
SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
|
SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
|
||||||
@@ -718,7 +759,7 @@ func TestRemoteSignerDisconnects(t *testing.T) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
success, err := coordinator.Ping(pingTimeout)
|
success, err := coordinator.Ping(pingTimeout)
|
||||||
require.Equal(t, ErrNotConnected, err)
|
require.Equal(t, ErrConnectTimeout, err)
|
||||||
require.False(t, success)
|
require.False(t, success)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -742,7 +783,7 @@ func TestRemoteSignerDisconnects(t *testing.T) {
|
|||||||
stream.Cancel()
|
stream.Cancel()
|
||||||
|
|
||||||
// This should cause the Run function to return the error that the
|
// This should cause the Run function to return the error that the
|
||||||
// stream was canceled with.
|
// stream was canceled.
|
||||||
err = <-runErrChan
|
err = <-runErrChan
|
||||||
require.Equal(t, ErrStreamCanceled, err)
|
require.Equal(t, ErrStreamCanceled, err)
|
||||||
|
|
||||||
@@ -752,11 +793,211 @@ func TestRemoteSignerDisconnects(t *testing.T) {
|
|||||||
// Verify that the coordinator signals that it's done receiving
|
// Verify that the coordinator signals that it's done receiving
|
||||||
// responses after the stream is canceled, i.e. the StartReceiving
|
// responses after the stream is canceled, i.e. the StartReceiving
|
||||||
// function is no longer running.
|
// function is no longer running.
|
||||||
<-coordinator.doneReceiving
|
<-coordinator.disconnected
|
||||||
|
|
||||||
// Ensure that the Ping request goroutine returned before the timeout
|
// Ensure that the Ping request goroutine returned before the timeout
|
||||||
// was reached, which indicates that the request was canceled because
|
// was reached, which indicates that the request was canceled because
|
||||||
// the remote signer disconnected.
|
// the remote signer disconnected.
|
||||||
|
require.Greater(t, time.Since(startTime), pingTimeout)
|
||||||
|
require.Less(t, time.Since(startTime), pingTimeout+100*time.Millisecond)
|
||||||
|
|
||||||
|
// Verify the responses map is empty after all responses are received
|
||||||
|
require.Equal(t, coordinator.responses.Len(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoteSignerReconnectsDuringResponseWait verifies that the sign
|
||||||
|
// coordinator correctly handles the scenario where the remote signer
|
||||||
|
// disconnects while a request is being processed and then reconnects. In this
|
||||||
|
// case, the sign coordinator should establish a new stream, reprocess the
|
||||||
|
// request, and ultimately receive a response.
|
||||||
|
func TestRemoteSignerReconnectsDuringResponseWait(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
coordinator, stream, runErrChan := setupSignCoordinator(t)
|
||||||
|
|
||||||
|
pingTimeout := 3 * time.Second
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Send a Ping request with a long timeout to ensure that the request
|
||||||
|
// will not time out before the remote signer disconnects.
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
success, err := coordinator.Ping(pingTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, success)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Get the request sent over the mock stream.
|
||||||
|
req, err := getRequest(stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that the request has the expected request ID and that it's a
|
||||||
|
// Ping request.
|
||||||
|
require.Equal(t, uint64(2), req.GetRequestId())
|
||||||
|
require.True(t, req.GetPing())
|
||||||
|
|
||||||
|
// Verify that the coordinator has correctly set up a single response
|
||||||
|
// channel for the Ping request with the specific request ID.
|
||||||
|
require.Equal(t, coordinator.responses.Len(), 1)
|
||||||
|
_, ok := coordinator.responses.Load(uint64(2))
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// Now, let's simulate that the remote signer disconnects by canceling
|
||||||
|
// the stream, while the sign coordinator is still waiting for the Pong
|
||||||
|
// response for the request it sent.
|
||||||
|
stream.Cancel()
|
||||||
|
|
||||||
|
// This should cause the Run function to return the error that the
|
||||||
|
// stream was canceled.
|
||||||
|
err = <-runErrChan
|
||||||
|
require.Equal(t, ErrStreamCanceled, err)
|
||||||
|
|
||||||
|
// Verify that the coordinator signals that it's done receiving
|
||||||
|
// responses after the stream is canceled, i.e. the StartReceiving
|
||||||
|
// function is no longer running.
|
||||||
|
<-coordinator.disconnected
|
||||||
|
|
||||||
|
// Now let's simulate that the remote signer reconnects with a new
|
||||||
|
// stream.
|
||||||
|
stream, runErrChan = setupNewStream(t, coordinator)
|
||||||
|
|
||||||
|
// This should lead to that the sign coordinator resends the Ping
|
||||||
|
// request it needs a response for over the new stream.
|
||||||
|
req, err = getRequest(stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Note that the request ID will be 3 for the resent request, as the
|
||||||
|
// coordinator will no longer wait for the response for the request with
|
||||||
|
// request ID 2.
|
||||||
|
require.Equal(t, uint64(3), req.GetRequestId())
|
||||||
|
require.True(t, req.GetPing())
|
||||||
|
|
||||||
|
// Verify that the coordinator has correctly set up a new response
|
||||||
|
// channel for the Ping request with the specific request ID.
|
||||||
|
require.Equal(t, coordinator.responses.Len(), 2)
|
||||||
|
_, ok = coordinator.responses.Load(uint64(3))
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// Now let's send the Pong response for the resent Ping request.
|
||||||
|
stream.sendResponse(&walletrpc.SignCoordinatorResponse{
|
||||||
|
RefRequestId: 3,
|
||||||
|
SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
|
||||||
|
Pong: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Ensure that the Ping request goroutine has finished.
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Ensure that the Ping request goroutine returned before the timeout
|
||||||
|
// was reached, which indicates that the request didn't time out as
|
||||||
|
// the remote signer reconnected in time and sent a response.
|
||||||
|
require.Less(t, time.Since(startTime), pingTimeout)
|
||||||
|
|
||||||
|
// Verify the responses map is empty after all responses are received
|
||||||
|
require.Equal(t, coordinator.responses.Len(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoteSignerDisconnectsMidSend verifies that the sign coordinator
|
||||||
|
// correctly handles the scenario in which the remote signer disconnects while
|
||||||
|
// the sign coordinator is sending data over the stream (i.e., during the
|
||||||
|
// execution of the `Send` function) and then reconnects. In such a case, the
|
||||||
|
// sign coordinator should establish a new stream, reprocess the request, and
|
||||||
|
// eventually receive a response.
|
||||||
|
func TestRemoteSignerDisconnectsMidSend(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
coordinator, stream, runErrChan := setupSignCoordinator(t)
|
||||||
|
|
||||||
|
pingTimeout := 3 * time.Second
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Send a Ping request with a long timeout to ensure that the request
|
||||||
|
// will not time out before the remote signer disconnects.
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
success, err := coordinator.Ping(pingTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, success)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Just wait slightly, to ensure that the Ping request starts getting
|
||||||
|
// processed before we simulate the remote signer disconnecting.
|
||||||
|
<-time.After(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// We simulate the remote signer disconnecting by canceling the
|
||||||
|
// stream.
|
||||||
|
stream.Cancel()
|
||||||
|
|
||||||
|
// This should cause the Run function to return the error that the
|
||||||
|
// stream was canceled.
|
||||||
|
err := <-runErrChan
|
||||||
|
require.Equal(t, ErrStreamCanceled, err)
|
||||||
|
|
||||||
|
// Verify that the coordinator signals that it's done receiving
|
||||||
|
// responses after the stream is canceled, i.e. the StartReceiving
|
||||||
|
// function is no longer running.
|
||||||
|
<-coordinator.disconnected
|
||||||
|
|
||||||
|
// Now since the sign coordinator is still processing the requests, and
|
||||||
|
// we never extracted the request sent over the stream, the sign
|
||||||
|
// coordinator is stuck at the stream.Send function. We simulate that this
|
||||||
|
// function now errors with the codes.Unavailable error, which is what
|
||||||
|
// the function would error with if the signer was disconnected during
|
||||||
|
// the send operation in a real scenario.
|
||||||
|
stream.sendErrorChan <- status.Errorf(
|
||||||
|
codes.Unavailable, "simulated unavailable error",
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify that the coordinator has correctly set up a single response
|
||||||
|
// channel for the Ping request with the specific request ID.
|
||||||
|
require.Equal(t, 1, coordinator.responses.Len())
|
||||||
|
|
||||||
|
// Now let's simulate that the remote signer reconnects with a new
|
||||||
|
// stream.
|
||||||
|
stream, runErrChan = setupNewStream(t, coordinator)
|
||||||
|
|
||||||
|
// This should lead to that the sign coordinator resends the Ping
|
||||||
|
// request it needs a response for over the new stream.
|
||||||
|
req, err := getRequest(stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Note that the request ID will be 3 for the resent request, as the
|
||||||
|
// coordinator will no longer wait for the response for the request with
|
||||||
|
// request ID 2.
|
||||||
|
require.Equal(t, uint64(3), req.GetRequestId())
|
||||||
|
require.True(t, req.GetPing())
|
||||||
|
|
||||||
|
// Verify that the coordinator has correctly set up a new response
|
||||||
|
// channel for the Ping request with the specific request ID.
|
||||||
|
require.Equal(t, coordinator.responses.Len(), 2)
|
||||||
|
_, ok := coordinator.responses.Load(uint64(3))
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// Now let's send the Pong response for the resent Ping request.
|
||||||
|
stream.sendResponse(&walletrpc.SignCoordinatorResponse{
|
||||||
|
RefRequestId: 3,
|
||||||
|
SignResponseType: &walletrpc.SignCoordinatorResponse_Pong{
|
||||||
|
Pong: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Ensure that the Ping request goroutine has finished.
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Ensure that the Ping request goroutine returned before the timeout
|
||||||
|
// was reached, which indicates that the request didn't time out as
|
||||||
|
// the remote signer reconnected in time and sent a response.
|
||||||
require.Less(t, time.Since(startTime), pingTimeout)
|
require.Less(t, time.Since(startTime), pingTimeout)
|
||||||
|
|
||||||
// Verify the responses map is empty after all responses are received
|
// Verify the responses map is empty after all responses are received
|
||||||
|
Reference in New Issue
Block a user