routing: check pubkey when applying channel update

This commit is contained in:
eugene 2022-09-12 16:15:52 -04:00 committed by yyforyongyu
parent 14f45d0722
commit 64b608bce0
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
3 changed files with 51 additions and 30 deletions

View File

@ -932,7 +932,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route,
} }
// Apply channel update to the channel edge policy in our db. // Apply channel update to the channel edge policy in our db.
if !p.router.applyChannelUpdate(update, errSource) { if !p.router.applyChannelUpdate(update) {
log.Debugf("Invalid channel update received: node=%v", log.Debugf("Invalid channel update received: node=%v",
errVertex) errVertex)
} }

View File

@ -2393,16 +2393,32 @@ func (r *ChannelRouter) extractChannelUpdate(
// applyChannelUpdate validates a channel update and if valid, applies it to the // applyChannelUpdate validates a channel update and if valid, applies it to the
// database. It returns a bool indicating whether the updates were successful. // database. It returns a bool indicating whether the updates were successful.
func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate, func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate) bool {
pubKey *btcec.PublicKey) bool {
ch, _, _, err := r.GetChannelByID(msg.ShortChannelID) ch, _, _, err := r.GetChannelByID(msg.ShortChannelID)
if err != nil { if err != nil {
log.Errorf("Unable to retrieve channel by id: %v", err) log.Errorf("Unable to retrieve channel by id: %v", err)
return false return false
} }
if err := ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg); err != nil { var pubKey *btcec.PublicKey
switch msg.ChannelFlags & lnwire.ChanUpdateDirection {
case 0:
pubKey, _ = ch.NodeKey1()
case 1:
pubKey, _ = ch.NodeKey2()
}
// Exit early if the pubkey cannot be decided.
if pubKey == nil {
log.Errorf("Unable to decide pubkey with ChannelFlags=%v",
msg.ChannelFlags)
return false
}
err = ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg)
if err != nil {
log.Errorf("Unable to validate channel update: %v", err) log.Errorf("Unable to validate channel update: %v", err)
return false return false
} }

View File

@ -385,7 +385,9 @@ func TestChannelUpdateValidation(t *testing.T) {
}, 2), }, 2),
} }
testGraph, err := createTestGraphFromChannels(t, true, testChannels, "a") testGraph, err := createTestGraphFromChannels(
t, true, testChannels, "a",
)
require.NoError(t, err, "unable to create graph") require.NoError(t, err, "unable to create graph")
const startingBlockHeight = 101 const startingBlockHeight = 101
@ -394,13 +396,13 @@ func TestChannelUpdateValidation(t *testing.T) {
) )
// Assert that the initially configured fee is retrieved correctly. // Assert that the initially configured fee is retrieved correctly.
_, policy, _, err := ctx.router.GetChannelByID( _, e1, e2, err := ctx.router.GetChannelByID(
lnwire.NewShortChanIDFromInt(1)) lnwire.NewShortChanIDFromInt(1),
)
require.NoError(t, err, "cannot retrieve channel") require.NoError(t, err, "cannot retrieve channel")
require.Equal(t, require.Equal(t, feeRate, e1.FeeProportionalMillionths, "invalid fee")
feeRate, policy.FeeProportionalMillionths, "invalid fee", require.Equal(t, feeRate, e2.FeeProportionalMillionths, "invalid fee")
)
// Setup a route from source a to destination c. The route will be used // Setup a route from source a to destination c. The route will be used
// in a call to SendToRoute. SendToRoute also applies channel updates, // in a call to SendToRoute. SendToRoute also applies channel updates,
@ -430,10 +432,13 @@ func TestChannelUpdateValidation(t *testing.T) {
// returned to the sender. // returned to the sender.
var invalidSignature [64]byte var invalidSignature [64]byte
errChanUpdate := lnwire.ChannelUpdate{ errChanUpdate := lnwire.ChannelUpdate{
Signature: invalidSignature, Signature: invalidSignature,
FeeRate: 500, FeeRate: 500,
ShortChannelID: lnwire.NewShortChanIDFromInt(1), ShortChannelID: lnwire.NewShortChanIDFromInt(1),
Timestamp: uint32(testTime.Add(time.Minute).Unix()), Timestamp: uint32(testTime.Add(time.Minute).Unix()),
MessageFlags: e2.MessageFlags,
ChannelFlags: e2.ChannelFlags,
HtlcMaximumMsat: e2.MaxHTLC,
} }
// We'll modify the SendToSwitch method so that it simulates a failed // We'll modify the SendToSwitch method so that it simulates a failed
@ -459,34 +464,34 @@ func TestChannelUpdateValidation(t *testing.T) {
_, err = ctx.router.SendToRoute(payment, rt) _, err = ctx.router.SendToRoute(payment, rt)
require.Error(t, err, "expected route to fail with channel update") require.Error(t, err, "expected route to fail with channel update")
_, policy, _, err = ctx.router.GetChannelByID( _, e1, e2, err = ctx.router.GetChannelByID(
lnwire.NewShortChanIDFromInt(1)) lnwire.NewShortChanIDFromInt(1),
)
require.NoError(t, err, "cannot retrieve channel") require.NoError(t, err, "cannot retrieve channel")
require.Equal(t, require.Equal(t, feeRate, e1.FeeProportionalMillionths,
feeRate, policy.FeeProportionalMillionths, "fee updated without valid signature")
"fee updated without valid signature", require.Equal(t, feeRate, e2.FeeProportionalMillionths,
) "fee updated without valid signature")
// Next, add a signature to the channel update. // Next, add a signature to the channel update.
signErrChanUpdate(t, testGraph.privKeyMap["b"], &errChanUpdate) signErrChanUpdate(t, testGraph.privKeyMap["b"], &errChanUpdate)
// Retry the payment using the same route as before. // Retry the payment using the same route as before.
_, err = ctx.router.SendToRoute(payment, rt) _, err = ctx.router.SendToRoute(payment, rt)
if err == nil { require.Error(t, err, "expected route to fail with channel update")
t.Fatalf("expected route to fail with channel update")
}
// This time a valid signature was supplied and the policy change should // This time a valid signature was supplied and the policy change should
// have been applied to the graph. // have been applied to the graph.
_, policy, _, err = ctx.router.GetChannelByID( _, e1, e2, err = ctx.router.GetChannelByID(
lnwire.NewShortChanIDFromInt(1)) lnwire.NewShortChanIDFromInt(1),
)
require.NoError(t, err, "cannot retrieve channel") require.NoError(t, err, "cannot retrieve channel")
require.Equal(t, require.Equal(t, feeRate, e1.FeeProportionalMillionths,
lnwire.MilliSatoshi(500), policy.FeeProportionalMillionths, "fee should not be updated")
"fee not updated even though signature is valid", require.EqualValues(t, 500, int(e2.FeeProportionalMillionths),
) "fee not updated even though signature is valid")
} }
// TestSendPaymentErrorRepeatedFeeInsufficient tests that if we receive // TestSendPaymentErrorRepeatedFeeInsufficient tests that if we receive