diff --git a/lntest/itest/lnd_rest_api_test.go b/lntest/itest/lnd_rest_api_test.go index eac605800..290b98171 100644 --- a/lntest/itest/lnd_rest_api_test.go +++ b/lntest/itest/lnd_rest_api_test.go @@ -438,8 +438,8 @@ func wsTestCaseBiDirectionalSubscription(ht *harnessTest, require.Nil(ht.t, err, "websocket") defer func() { err := conn.WriteMessage(websocket.CloseMessage, closeMsg) - require.NoError(ht.t, err) _ = conn.Close() + require.NoError(ht.t, err) }() // Buffer the message channel to make sure we're always blocking on @@ -576,21 +576,27 @@ func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) { require.Nil(ht.t, err, "websocket") defer func() { err := conn.WriteMessage(websocket.CloseMessage, closeMsg) - require.NoError(ht.t, err) _ = conn.Close() + require.NoError(ht.t, err) }() // We want to be able to read invoices for a long time, making sure we // can continue to read even after we've gone through several ping/pong // cycles. invoices := make(chan *lnrpc.Invoice, 1) - errors := make(chan error) + errChan := make(chan error) done := make(chan struct{}) + timeout := time.After(defaultTimeout) + + defer close(done) go func() { for { _, msg, err := conn.ReadMessage() if err != nil { - errors <- err + select { + case errChan <- err: + case <-done: + } return } @@ -599,7 +605,11 @@ func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) { // get rid of here. msgStr := string(msg) if !strings.Contains(msgStr, "\"result\":") { - errors <- fmt.Errorf("invalid msg: %s", msgStr) + select { + case errChan <- fmt.Errorf("invalid msg: %s", + msgStr): + case <-done: + } return } msgStr = resultPattern.ReplaceAllString(msgStr, "${1}") @@ -609,7 +619,10 @@ func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) { protoMsg := &lnrpc.Invoice{} err = jsonpb.UnmarshalString(msgStr, protoMsg) if err != nil { - errors <- err + select { + case errChan <- err: + case <-done: + } return } @@ -643,8 +656,11 @@ func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) { require.Equal(ht.t, int64(value), streamMsg.Value) require.Equal(ht.t, memo, streamMsg.Memo) - case err := <-errors: + case err := <-errChan: require.Fail(ht.t, "Error reading invoice: %v", err) + + case <-timeout: + require.Fail(ht.t, "No invoice msg received in time") } // Let's wait for at least a whole ping/pong cycle to happen, so @@ -652,7 +668,6 @@ func wsTestPingPongTimeout(ht *harnessTest, net *lntest.NetworkHarness) { // We double the pong wait just to add some extra margin. time.Sleep(pingInterval + 2*pongWait) } - close(done) } // invokeGET calls the given URL with the GET method and appropriate macaroon