lntemp+itest: refactor testRPCMiddlewareInterceptor

This commit is contained in:
yyforyongyu
2023-01-16 07:49:16 +08:00
parent ad77a45112
commit 84278d6a49
5 changed files with 195 additions and 131 deletions

View File

@@ -611,3 +611,19 @@ func (h *HarnessRPC) ForwardingHistory(
return resp return resp
} }
type MiddlewareClient lnrpc.Lightning_RegisterRPCMiddlewareClient
// RegisterRPCMiddleware makes a RPC call to the node's RegisterRPCMiddleware
// and asserts. It also returns a cancel context which can cancel the context
// used by the client.
func (h *HarnessRPC) RegisterRPCMiddleware() (MiddlewareClient,
context.CancelFunc) {
ctxt, cancel := context.WithCancel(h.runCtx)
stream, err := h.LN.RegisterRPCMiddleware(ctxt)
h.NoError(err, "RegisterRPCMiddleware")
return stream, cancel
}

View File

@@ -341,4 +341,8 @@ var allTestCasesTemp = []*lntemp.TestCase{
Name: "route fee cutoff", Name: "route fee cutoff",
TestFunc: testRouteFeeCutoff, TestFunc: testRouteFeeCutoff,
}, },
{
Name: "rpc middleware interceptor",
TestFunc: testRPCMiddlewareInterceptor,
},
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntemp/node"
"github.com/lightningnetwork/lnd/lntest" "github.com/lightningnetwork/lnd/lntest"
"github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/macaroons"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -58,7 +59,7 @@ func testMacaroonAuthentication(net *lntest.NetworkHarness, ht *harnessTest) {
[]byte("dummy_root_key"), []byte("0"), "itest", []byte("dummy_root_key"), []byte("0"), "itest",
macaroon.LatestVersion, macaroon.LatestVersion,
) )
cleanup, client := macaroonClient( cleanup, client := macaroonClientOld(
t, testNode, invalidMac, t, testNode, invalidMac,
) )
defer cleanup() defer cleanup()
@@ -75,7 +76,7 @@ func testMacaroonAuthentication(net *lntest.NetworkHarness, ht *harnessTest) {
testNode.ReadMacPath(), defaultTimeout, testNode.ReadMacPath(), defaultTimeout,
) )
require.NoError(t, err) require.NoError(t, err)
cleanup, client := macaroonClient( cleanup, client := macaroonClientOld(
t, testNode, readonlyMac, t, testNode, readonlyMac,
) )
defer cleanup() defer cleanup()
@@ -96,7 +97,7 @@ func testMacaroonAuthentication(net *lntest.NetworkHarness, ht *harnessTest) {
readonlyMac, macaroons.TimeoutConstraint(-30), readonlyMac, macaroons.TimeoutConstraint(-30),
) )
require.NoError(t, err) require.NoError(t, err)
cleanup, client := macaroonClient( cleanup, client := macaroonClientOld(
t, testNode, timeoutMac, t, testNode, timeoutMac,
) )
defer cleanup() defer cleanup()
@@ -118,7 +119,7 @@ func testMacaroonAuthentication(net *lntest.NetworkHarness, ht *harnessTest) {
), ),
) )
require.NoError(t, err) require.NoError(t, err)
cleanup, client := macaroonClient( cleanup, client := macaroonClientOld(
t, testNode, invalidIPAddrMac, t, testNode, invalidIPAddrMac,
) )
defer cleanup() defer cleanup()
@@ -141,7 +142,7 @@ func testMacaroonAuthentication(net *lntest.NetworkHarness, ht *harnessTest) {
macaroons.IPLockConstraint("127.0.0.1"), macaroons.IPLockConstraint("127.0.0.1"),
) )
require.NoError(t, err) require.NoError(t, err)
cleanup, client := macaroonClient(t, testNode, adminMac) cleanup, client := macaroonClientOld(t, testNode, adminMac)
defer cleanup() defer cleanup()
res, err := client.NewAddress(ctxt, newAddrReq) res, err := client.NewAddress(ctxt, newAddrReq)
require.NoError(t, err, "get new address") require.NoError(t, err, "get new address")
@@ -174,7 +175,7 @@ func testMacaroonAuthentication(net *lntest.NetworkHarness, ht *harnessTest) {
customMac := &macaroon.Macaroon{} customMac := &macaroon.Macaroon{}
err = customMac.UnmarshalBinary(customMacBytes) err = customMac.UnmarshalBinary(customMacBytes)
require.NoError(t, err) require.NoError(t, err)
cleanup, client := macaroonClient( cleanup, client := macaroonClientOld(
t, testNode, customMac, t, testNode, customMac,
) )
defer cleanup() defer cleanup()
@@ -426,7 +427,7 @@ func testBakeMacaroon(net *lntest.NetworkHarness, t *harnessTest) {
newMac, err := readMacaroonFromHex(bakeResp.Macaroon) newMac, err := readMacaroonFromHex(bakeResp.Macaroon)
require.NoError(t, err) require.NoError(t, err)
cleanup, readOnlyClient := macaroonClient( cleanup, readOnlyClient := macaroonClientOld(
t, testNode, newMac, t, testNode, newMac,
) )
defer cleanup() defer cleanup()
@@ -505,7 +506,7 @@ func testBakeMacaroon(net *lntest.NetworkHarness, t *harnessTest) {
testNode.AdminMacPath(), defaultTimeout, testNode.AdminMacPath(), defaultTimeout,
) )
require.NoError(tt, err) require.NoError(tt, err)
cleanup, client := macaroonClient(tt, testNode, adminMac) cleanup, client := macaroonClientOld(tt, testNode, adminMac)
defer cleanup() defer cleanup()
tc.run(ctxt, tt, client) tc.run(ctxt, tt, client)
@@ -530,7 +531,7 @@ func testDeleteMacaroonID(net *lntest.NetworkHarness, t *harnessTest) {
testNode.AdminMacPath(), defaultTimeout, testNode.AdminMacPath(), defaultTimeout,
) )
require.NoError(t.t, err) require.NoError(t.t, err)
cleanup, client := macaroonClient(t.t, testNode, adminMac) cleanup, client := macaroonClientOld(t.t, testNode, adminMac)
defer cleanup() defer cleanup()
// Record the number of macaroon IDs before creation. // Record the number of macaroon IDs before creation.
@@ -595,7 +596,7 @@ func testDeleteMacaroonID(net *lntest.NetworkHarness, t *harnessTest) {
// Check that the deleted macaroon can no longer access macaroon:read. // Check that the deleted macaroon can no longer access macaroon:read.
deletedMac, err := readMacaroonFromHex(macList[0]) deletedMac, err := readMacaroonFromHex(macList[0])
require.NoError(t.t, err) require.NoError(t.t, err)
cleanup, client = macaroonClient(t.t, testNode, deletedMac) cleanup, client = macaroonClientOld(t.t, testNode, deletedMac)
defer cleanup() defer cleanup()
// Because the macaroon is deleted, it will be treated as an invalid one. // Because the macaroon is deleted, it will be treated as an invalid one.
@@ -717,7 +718,8 @@ func readMacaroonFromHex(macHex string) (*macaroon.Macaroon, error) {
return mac, nil return mac, nil
} }
func macaroonClient(t *testing.T, testNode *lntest.HarnessNode, // TODO(yy): remove.
func macaroonClientOld(t *testing.T, testNode *lntest.HarnessNode,
mac *macaroon.Macaroon) (func(), lnrpc.LightningClient) { mac *macaroon.Macaroon) (func(), lnrpc.LightningClient) {
conn, err := testNode.ConnectRPCWithMacaroon(mac) conn, err := testNode.ConnectRPCWithMacaroon(mac)
@@ -729,3 +731,18 @@ func macaroonClient(t *testing.T, testNode *lntest.HarnessNode,
} }
return cleanup, lnrpc.NewLightningClient(conn) return cleanup, lnrpc.NewLightningClient(conn)
} }
func macaroonClient(t *testing.T, testNode *node.HarnessNode,
mac *macaroon.Macaroon) (func(), lnrpc.LightningClient) {
t.Helper()
conn, err := testNode.ConnectRPCWithMacaroon(mac)
require.NoError(t, err, "connect to alice")
cleanup := func() {
err := conn.Close()
require.NoError(t, err, "close")
}
return cleanup, lnrpc.NewLightningClient(conn)
}

View File

@@ -8,7 +8,8 @@ import (
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntest" "github.com/lightningnetwork/lnd/lntemp"
"github.com/lightningnetwork/lnd/lntemp/node"
"github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -18,133 +19,149 @@ import (
// testRPCMiddlewareInterceptor tests that the RPC middleware interceptor can // testRPCMiddlewareInterceptor tests that the RPC middleware interceptor can
// be used correctly and in a safe way. // be used correctly and in a safe way.
func testRPCMiddlewareInterceptor(net *lntest.NetworkHarness, t *harnessTest) { func testRPCMiddlewareInterceptor(ht *lntemp.HarnessTest) {
// Let's first enable the middleware interceptor. // Let's first enable the middleware interceptor.
net.Alice.Cfg.ExtraArgs = append( //
net.Alice.Cfg.ExtraArgs, "--rpcmiddleware.enable", // NOTE: we cannot use standby nodes here as the test messes with
) // middleware interceptor. Thus we also skip the calling of cleanup of
err := net.RestartNode(net.Alice, nil) // each of the following subtests because no standby nodes are used.
require.NoError(t.t, err) alice := ht.NewNode("alice", []string{"--rpcmiddleware.enable"})
bob := ht.NewNode("bob", nil)
// Let's set up a channel between Alice and Bob, just to get some useful // Let's set up a channel between Alice and Bob, just to get some useful
// data to inspect when doing RPC calls to Alice later. // data to inspect when doing RPC calls to Alice later.
net.EnsureConnected(t.t, net.Alice, net.Bob) ht.EnsureConnected(alice, bob)
net.SendCoins(t.t, btcutil.SatoshiPerBitcoin, net.Alice) ht.FundCoins(btcutil.SatoshiPerBitcoin, alice)
_ = openChannelAndAssert( ht.OpenChannel(alice, bob, lntemp.OpenChannelParams{Amt: 1_234_567})
t, net, net.Alice, net.Bob, lntest.OpenChannelParams{
Amt: 1_234_567,
},
)
// Load or bake the macaroons that the simulated users will use to // Load or bake the macaroons that the simulated users will use to
// access the RPC. // access the RPC.
readonlyMac, err := net.Alice.ReadMacaroon( readonlyMac, err := alice.ReadMacaroon(
net.Alice.ReadMacPath(), defaultTimeout, alice.Cfg.ReadMacPath, defaultTimeout,
) )
require.NoError(t.t, err) require.NoError(ht, err)
adminMac, err := net.Alice.ReadMacaroon( adminMac, err := alice.ReadMacaroon(
net.Alice.AdminMacPath(), defaultTimeout, alice.Cfg.AdminMacPath, defaultTimeout,
) )
require.NoError(t.t, err) require.NoError(ht, err)
customCaveatReadonlyMac, err := macaroons.SafeCopyMacaroon(readonlyMac) customCaveatReadonlyMac, err := macaroons.SafeCopyMacaroon(readonlyMac)
require.NoError(t.t, err) require.NoError(ht, err)
addConstraint := macaroons.CustomConstraint( addConstraint := macaroons.CustomConstraint(
"itest-caveat", "itest-value", "itest-caveat", "itest-value",
) )
require.NoError(t.t, addConstraint(customCaveatReadonlyMac)) require.NoError(ht, addConstraint(customCaveatReadonlyMac))
customCaveatAdminMac, err := macaroons.SafeCopyMacaroon(adminMac) customCaveatAdminMac, err := macaroons.SafeCopyMacaroon(adminMac)
require.NoError(t.t, err) require.NoError(ht, err)
require.NoError(t.t, addConstraint(customCaveatAdminMac)) require.NoError(ht, addConstraint(customCaveatAdminMac))
// Run all sub-tests now. We can't run anything in parallel because that // Run all sub-tests now. We can't run anything in parallel because that
// would cause the main test function to exit and the nodes being // would cause the main test function to exit and the nodes being
// cleaned up. // cleaned up.
t.t.Run("registration restrictions", func(tt *testing.T) { ht.Run("registration restrictions", func(tt *testing.T) {
middlewareRegistrationRestrictionTests(tt, net.Alice) middlewareRegistrationRestrictionTests(tt, alice)
}) })
t.t.Run("read-only intercept", func(tt *testing.T) {
ht.Run("read-only intercept", func(tt *testing.T) {
registration := registerMiddleware( registration := registerMiddleware(
tt, net.Alice, &lnrpc.MiddlewareRegistration{ tt, alice, &lnrpc.MiddlewareRegistration{
MiddlewareName: "itest-interceptor", MiddlewareName: "itest-interceptor-1",
ReadOnlyMode: true, ReadOnlyMode: true,
}, true, }, true,
) )
defer registration.cancel() defer registration.cancel()
middlewareInterceptionTest( middlewareInterceptionTest(
tt, net.Alice, net.Bob, registration, readonlyMac, tt, alice, bob, registration, readonlyMac,
customCaveatReadonlyMac, true, customCaveatReadonlyMac, true,
) )
}) })
// We've manually disconnected Bob from Alice in the previous test, make // We've manually disconnected Bob from Alice in the previous test, make
// sure they're connected again. // sure they're connected again.
net.EnsureConnected(t.t, net.Alice, net.Bob) //
t.t.Run("encumbered macaroon intercept", func(tt *testing.T) { // NOTE: we may get an error here saying "interceptor RPC client quit"
// as it takes some time for the interceptor to fully quit. Thus we
// restart the node here to make sure the old interceptor is removed
// from registration.
ht.RestartNode(alice)
ht.EnsureConnected(alice, bob)
ht.Run("encumbered macaroon intercept", func(tt *testing.T) {
registration := registerMiddleware( registration := registerMiddleware(
tt, net.Alice, &lnrpc.MiddlewareRegistration{ tt, alice, &lnrpc.MiddlewareRegistration{
MiddlewareName: "itest-interceptor", MiddlewareName: "itest-interceptor-2",
CustomMacaroonCaveatName: "itest-caveat", CustomMacaroonCaveatName: "itest-caveat",
}, true, }, true,
) )
defer registration.cancel() defer registration.cancel()
middlewareInterceptionTest( middlewareInterceptionTest(
tt, net.Alice, net.Bob, registration, tt, alice, bob, registration,
customCaveatReadonlyMac, readonlyMac, false, customCaveatReadonlyMac, readonlyMac, false,
) )
}) })
// Next, run the response manipulation tests. // Next, run the response manipulation tests.
net.EnsureConnected(t.t, net.Alice, net.Bob) //
t.t.Run("read-only not allowed to manipulate", func(tt *testing.T) { // NOTE: we may get an error here saying "interceptor RPC client quit"
// as it takes some time for the interceptor to fully quit. Thus we
// restart the node here to make sure the old interceptor is removed
// from registration.
ht.RestartNode(alice)
ht.EnsureConnected(alice, bob)
ht.Run("read-only not allowed to manipulate", func(tt *testing.T) {
registration := registerMiddleware( registration := registerMiddleware(
tt, net.Alice, &lnrpc.MiddlewareRegistration{ tt, alice, &lnrpc.MiddlewareRegistration{
MiddlewareName: "itest-interceptor", MiddlewareName: "itest-interceptor-3",
ReadOnlyMode: true, ReadOnlyMode: true,
}, true, }, true,
) )
defer registration.cancel() defer registration.cancel()
middlewareRequestManipulationTest( middlewareRequestManipulationTest(
tt, net.Alice, registration, adminMac, true, tt, alice, registration, adminMac, true,
) )
middlewareResponseManipulationTest( middlewareResponseManipulationTest(
tt, net.Alice, net.Bob, registration, readonlyMac, true, tt, alice, bob, registration, readonlyMac, true,
) )
}) })
net.EnsureConnected(t.t, net.Alice, net.Bob)
t.t.Run("encumbered macaroon manipulate", func(tt *testing.T) { // NOTE: we may get an error here saying "interceptor RPC client quit"
// as it takes some time for the interceptor to fully quit. Thus we
// restart the node here to make sure the old interceptor is removed
// from registration.
ht.RestartNode(alice)
ht.EnsureConnected(alice, bob)
ht.Run("encumbered macaroon manipulate", func(tt *testing.T) {
registration := registerMiddleware( registration := registerMiddleware(
tt, net.Alice, &lnrpc.MiddlewareRegistration{ tt, alice, &lnrpc.MiddlewareRegistration{
MiddlewareName: "itest-interceptor", MiddlewareName: "itest-interceptor-4",
CustomMacaroonCaveatName: "itest-caveat", CustomMacaroonCaveatName: "itest-caveat",
}, true, }, true,
) )
defer registration.cancel() defer registration.cancel()
middlewareRequestManipulationTest( middlewareRequestManipulationTest(
tt, net.Alice, registration, customCaveatAdminMac, tt, alice, registration, customCaveatAdminMac, false,
false,
) )
middlewareResponseManipulationTest( middlewareResponseManipulationTest(
tt, net.Alice, net.Bob, registration, tt, alice, bob, registration,
customCaveatReadonlyMac, false, customCaveatReadonlyMac, false,
) )
}) })
// And finally make sure mandatory middleware is always checked for any // And finally make sure mandatory middleware is always checked for any
// RPC request. // RPC request.
t.t.Run("mandatory middleware", func(tt *testing.T) { ht.Run("mandatory middleware", func(tt *testing.T) {
middlewareMandatoryTest(tt, net.Alice, net) st := ht.Subtest(tt)
middlewareMandatoryTest(st, alice)
}) })
} }
// middlewareRegistrationRestrictionTests tests all restrictions that apply to // middlewareRegistrationRestrictionTests tests all restrictions that apply to
// registering a middleware. // registering a middleware.
func middlewareRegistrationRestrictionTests(t *testing.T, func middlewareRegistrationRestrictionTests(t *testing.T,
node *lntest.HarnessNode) { node *node.HarnessNode) {
testCases := []struct { testCases := []struct {
registration *lnrpc.MiddlewareRegistration registration *lnrpc.MiddlewareRegistration
@@ -189,10 +206,12 @@ func middlewareRegistrationRestrictionTests(t *testing.T,
// intercepted. It also makes sure that depending on the mode (read-only or // intercepted. It also makes sure that depending on the mode (read-only or
// custom macaroon caveat) a middleware only gets access to the requests it // custom macaroon caveat) a middleware only gets access to the requests it
// should be allowed access to. // should be allowed access to.
func middlewareInterceptionTest(t *testing.T, node *lntest.HarnessNode, func middlewareInterceptionTest(t *testing.T,
peer *lntest.HarnessNode, registration *middlewareHarness, node, peer *node.HarnessNode, registration *middlewareHarness,
userMac *macaroon.Macaroon, disallowedMac *macaroon.Macaroon, userMac *macaroon.Macaroon,
readOnly bool) { disallowedMac *macaroon.Macaroon, readOnly bool) {
t.Helper()
// Everything we test here should be executed in a matter of // Everything we test here should be executed in a matter of
// milliseconds, so we can use one single timeout context for all calls. // milliseconds, so we can use one single timeout context for all calls.
@@ -253,10 +272,7 @@ func middlewareInterceptionTest(t *testing.T, node *lntest.HarnessNode,
// Disconnect Bob to trigger a peer event without using Alice's RPC // Disconnect Bob to trigger a peer event without using Alice's RPC
// interface itself. // interface itself.
_, err = peer.DisconnectPeer(ctxc, &lnrpc.DisconnectPeerRequest{ peer.RPC.DisconnectPeer(node.PubKeyStr)
PubKey: node.PubKeyStr,
})
require.NoError(t, err)
peerEvent, err := resp2.Recv() peerEvent, err := resp2.Recv()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, lnrpc.PeerEvent_PEER_OFFLINE, peerEvent.GetType()) require.Equal(t, lnrpc.PeerEvent_PEER_OFFLINE, peerEvent.GetType())
@@ -330,10 +346,12 @@ func middlewareInterceptionTest(t *testing.T, node *lntest.HarnessNode,
// middlewareResponseManipulationTest tests that unary and streaming responses // middlewareResponseManipulationTest tests that unary and streaming responses
// can be intercepted and also manipulated, at least if the middleware didn't // can be intercepted and also manipulated, at least if the middleware didn't
// register for read-only access. // register for read-only access.
func middlewareResponseManipulationTest(t *testing.T, node *lntest.HarnessNode, func middlewareResponseManipulationTest(t *testing.T,
peer *lntest.HarnessNode, registration *middlewareHarness, node, peer *node.HarnessNode, registration *middlewareHarness,
userMac *macaroon.Macaroon, readOnly bool) { userMac *macaroon.Macaroon, readOnly bool) {
t.Helper()
// Everything we test here should be executed in a matter of // Everything we test here should be executed in a matter of
// milliseconds, so we can use one single timeout context for all calls. // milliseconds, so we can use one single timeout context for all calls.
ctxb := context.Background() ctxb := context.Background()
@@ -421,10 +439,7 @@ func middlewareResponseManipulationTest(t *testing.T, node *lntest.HarnessNode,
// Disconnect Bob to trigger a peer event without using Alice's RPC // Disconnect Bob to trigger a peer event without using Alice's RPC
// interface itself. // interface itself.
_, err = peer.DisconnectPeer(ctxc, &lnrpc.DisconnectPeerRequest{ peer.RPC.DisconnectPeer(node.PubKeyStr)
PubKey: node.PubKeyStr,
})
require.NoError(t, err)
peerEvent, err := resp2.Recv() peerEvent, err := resp2.Recv()
require.NoError(t, err) require.NoError(t, err)
@@ -448,10 +463,12 @@ func middlewareResponseManipulationTest(t *testing.T, node *lntest.HarnessNode,
// middlewareRequestManipulationTest tests that unary and streaming requests // middlewareRequestManipulationTest tests that unary and streaming requests
// can be intercepted and also manipulated, at least if the middleware didn't // can be intercepted and also manipulated, at least if the middleware didn't
// register for read-only access. // register for read-only access.
func middlewareRequestManipulationTest(t *testing.T, node *lntest.HarnessNode, func middlewareRequestManipulationTest(t *testing.T, node *node.HarnessNode,
registration *middlewareHarness, userMac *macaroon.Macaroon, registration *middlewareHarness, userMac *macaroon.Macaroon,
readOnly bool) { readOnly bool) {
t.Helper()
// Everything we test here should be executed in a matter of // Everything we test here should be executed in a matter of
// milliseconds, so we can use one single timeout context for all calls. // milliseconds, so we can use one single timeout context for all calls.
ctxb := context.Background() ctxb := context.Background()
@@ -528,54 +545,44 @@ func middlewareRequestManipulationTest(t *testing.T, node *lntest.HarnessNode,
// middlewareMandatoryTest tests that all RPC requests are blocked if there is // middlewareMandatoryTest tests that all RPC requests are blocked if there is
// a mandatory middleware declared that's currently not registered. // a mandatory middleware declared that's currently not registered.
func middlewareMandatoryTest(t *testing.T, node *lntest.HarnessNode, func middlewareMandatoryTest(ht *lntemp.HarnessTest, node *node.HarnessNode) {
net *lntest.NetworkHarness) {
// Let's declare our itest interceptor as mandatory but don't register // Let's declare our itest interceptor as mandatory but don't register
// it just yet. That should cause all RPC requests to fail, except for // it just yet. That should cause all RPC requests to fail, except for
// the registration itself. // the registration itself.
node.Cfg.ExtraArgs = append( node.Cfg.SkipUnlock = true
node.Cfg.ExtraArgs, ht.RestartNodeWithExtraArgs(node, []string{
"--noseedbackup", "--rpcmiddleware.enable",
"--rpcmiddleware.addmandatory=itest-interceptor", "--rpcmiddleware.addmandatory=itest-interceptor",
) })
err := net.RestartNodeNoUnlock(node, nil, false)
require.NoError(t, err)
// The "wait for node to start" flag of the above restart does too much // The "wait for node to start" flag of the above restart does too much
// and has a call to GetInfo built in, which will fail in this special // and has a call to GetInfo built in, which will fail in this special
// test case. So we need to do the wait and client setup manually here. // test case. So we need to do the wait and client setup manually here.
conn, err := node.ConnectRPC(true) conn, err := node.ConnectRPC()
require.NoError(t, err) require.NoError(ht, err)
node.InitRPCClients(conn) node.InitRPCClients(conn)
err = node.WaitUntilStateReached(lnrpc.WalletState_RPC_ACTIVE) err = node.WaitUntilServerActive()
require.NoError(t, err) require.NoError(ht, err)
node.LightningClient = lnrpc.NewLightningClient(conn)
ctxb := context.Background() ctxb := context.Background()
ctxc, cancel := context.WithTimeout(ctxb, defaultTimeout) ctxc, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel() defer cancel()
// Test a unary request first. // Test a unary request first.
_, err = node.ListChannels(ctxc, &lnrpc.ListChannelsRequest{}) _, err = node.RPC.LN.ListChannels(ctxc, &lnrpc.ListChannelsRequest{})
require.Error(t, err) require.Contains(ht, err.Error(), "middleware 'itest-interceptor' is "+
require.Contains( "currently not registered")
t, err.Error(), "middleware 'itest-interceptor' is "+
"currently not registered",
)
// Then a streaming one. // Then a streaming one.
stream, err := node.SubscribeInvoices(ctxc, &lnrpc.InvoiceSubscription{}) stream := node.RPC.SubscribeInvoices(&lnrpc.InvoiceSubscription{})
require.NoError(t, err)
_, err = stream.Recv() _, err = stream.Recv()
require.Error(t, err) require.Error(ht, err)
require.Contains( require.Contains(ht, err.Error(), "middleware 'itest-interceptor' is "+
t, err.Error(), "middleware 'itest-interceptor' is "+ "currently not registered")
"currently not registered",
)
// Now let's register the middleware and try again. // Now let's register the middleware and try again.
registration := registerMiddleware( registration := registerMiddleware(
t, node, &lnrpc.MiddlewareRegistration{ ht.T, node, &lnrpc.MiddlewareRegistration{
MiddlewareName: "itest-interceptor", MiddlewareName: "itest-interceptor",
CustomMacaroonCaveatName: "itest-caveat", CustomMacaroonCaveatName: "itest-caveat",
}, true, }, true,
@@ -584,16 +591,13 @@ func middlewareMandatoryTest(t *testing.T, node *lntest.HarnessNode,
// Both the unary and streaming requests should now be allowed. // Both the unary and streaming requests should now be allowed.
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
_, err = node.ListChannels(ctxc, &lnrpc.ListChannelsRequest{}) node.RPC.ListChannels(&lnrpc.ListChannelsRequest{})
require.NoError(t, err) node.RPC.SubscribeInvoices(&lnrpc.InvoiceSubscription{})
_, err = node.SubscribeInvoices(ctxc, &lnrpc.InvoiceSubscription{})
require.NoError(t, err)
// We now shut down the node manually to prevent the test from failing // We now shut down the node manually to prevent the test from failing
// because we can't call the stop RPC if we unregister the middleware in // because we can't call the stop RPC if we unregister the middleware
// the defer statement above. // in the defer statement above.
err = net.ShutdownNode(node) ht.KillNode(node)
require.NoError(t, err)
} }
// assertInterceptedType makes sure that the intercept message sent by the RPC // assertInterceptedType makes sure that the intercept message sent by the RPC
@@ -648,35 +652,62 @@ type middlewareHarness struct {
// registerMiddleware creates a new middleware harness and sends the initial // registerMiddleware creates a new middleware harness and sends the initial
// register message to the RPC server. // register message to the RPC server.
func registerMiddleware(t *testing.T, node *lntest.HarnessNode, func registerMiddleware(t *testing.T, node *node.HarnessNode,
registration *lnrpc.MiddlewareRegistration, registration *lnrpc.MiddlewareRegistration,
waitForRegister bool) *middlewareHarness { waitForRegister bool) *middlewareHarness {
ctxc, cancel := context.WithCancel(context.Background()) t.Helper()
middlewareStream, err := node.RegisterRPCMiddleware(ctxc) middlewareStream, cancel := node.RPC.RegisterRPCMiddleware()
require.NoError(t, err)
err = middlewareStream.Send(&lnrpc.RPCMiddlewareResponse{ errChan := make(chan error)
MiddlewareMessage: &lnrpc.RPCMiddlewareResponse_Register{ go func() {
msg := &lnrpc.RPCMiddlewareResponse_Register{
Register: registration, Register: registration,
}, }
}) err := middlewareStream.Send(&lnrpc.RPCMiddlewareResponse{
require.NoError(t, err) MiddlewareMessage: msg,
})
if waitForRegister { errChan <- err
// Wait for the registration complete message. }()
regCompleteMsg, err := middlewareStream.Recv()
require.NoError(t, err) select {
require.True(t, regCompleteMsg.GetRegComplete()) case <-time.After(defaultTimeout):
require.Fail(t, "registerMiddleware send timeout")
case err := <-errChan:
require.NoError(t, err, "registerMiddleware send failed")
} }
return &middlewareHarness{ mh := &middlewareHarness{
t: t, t: t,
cancel: cancel, cancel: cancel,
stream: middlewareStream, stream: middlewareStream,
responsesChan: make(chan *lnrpc.RPCMessage), responsesChan: make(chan *lnrpc.RPCMessage),
} }
if !waitForRegister {
return mh
}
// Wait for the registration complete message.
msg := make(chan *lnrpc.RPCMiddlewareRequest)
go func() {
regCompleteMsg, err := middlewareStream.Recv()
require.NoError(t, err, "registerMiddleware recv failed")
msg <- regCompleteMsg
}()
select {
case <-time.After(defaultTimeout):
require.Fail(t, "registerMiddleware recv timeout")
case m := <-msg:
require.True(t, m.GetRegComplete())
}
return mh
} }
// interceptUnary intercepts a unary call, optionally requesting to replace the // interceptUnary intercepts a unary call, optionally requesting to replace the

View File

@@ -116,10 +116,6 @@ var allTestCases = []*testCase{
name: "wallet import pubkey", name: "wallet import pubkey",
test: testWalletImportPubKey, test: testWalletImportPubKey,
}, },
{
name: "rpc middleware interceptor",
test: testRPCMiddlewareInterceptor,
},
{ {
name: "wipe forwarding packages", name: "wipe forwarding packages",
test: testWipeForwardingPackages, test: testWipeForwardingPackages,