lnwallet/chancloser: enforce pubkey binding for msg mapper

This commit is contained in:
Olaoluwa Osuntokun
2025-02-07 18:10:45 -08:00
parent ed8a672bd3
commit b2794b07cb
4 changed files with 22 additions and 13 deletions

View File

@@ -1,8 +1,10 @@
package chancloser package chancloser
import ( import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/msgmux"
) )
// RbfMsgMapper is a struct that implements the MsgMapper interface for the // RbfMsgMapper is a struct that implements the MsgMapper interface for the
@@ -16,16 +18,21 @@ type RbfMsgMapper struct {
// chanID is the channel ID of the channel being closed. // chanID is the channel ID of the channel being closed.
chanID lnwire.ChannelID chanID lnwire.ChannelID
// peerPub is the public key of the peer that the channel is being
// closed.
peerPub btcec.PublicKey
} }
// NewRbfMsgMapper creates a new RbfMsgMapper instance given the current block // NewRbfMsgMapper creates a new RbfMsgMapper instance given the current block
// height when the co-op close request was initiated. // height when the co-op close request was initiated.
func NewRbfMsgMapper(blockHeight uint32, func NewRbfMsgMapper(blockHeight uint32,
chanID lnwire.ChannelID) *RbfMsgMapper { chanID lnwire.ChannelID, peerPub btcec.PublicKey) *RbfMsgMapper {
return &RbfMsgMapper{ return &RbfMsgMapper{
blockHeight: blockHeight, blockHeight: blockHeight,
chanID: chanID, chanID: chanID,
peerPub: peerPub,
} }
} }
@@ -34,18 +41,20 @@ func someEvent[T ProtocolEvent](m T) fn.Option[ProtocolEvent] {
return fn.Some(ProtocolEvent(m)) return fn.Some(ProtocolEvent(m))
} }
// isExpectedChanID returns true if the channel ID of the message matches the // isForUs returns true if the channel ID + pubkey of the message matches the
// bound instance. // bound instance.
func (r *RbfMsgMapper) isExpectedChanID(chanID lnwire.ChannelID) bool { func (r *RbfMsgMapper) isForUs(chanID lnwire.ChannelID,
return r.chanID == chanID fromPub btcec.PublicKey) bool {
return r.chanID == chanID && r.peerPub.IsEqual(&fromPub)
} }
// MapMsg maps a wire message into a FSM event. If the message is not mappable, // MapMsg maps a wire message into a FSM event. If the message is not mappable,
// then an error is returned. // then an error is returned.
func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] { func (r *RbfMsgMapper) MapMsg(wireMsg msgmux.PeerMsg) fn.Option[ProtocolEvent] {
switch msg := wireMsg.(type) { switch msg := wireMsg.Message.(type) {
case *lnwire.Shutdown: case *lnwire.Shutdown:
if !r.isExpectedChanID(msg.ChannelID) { if !r.isForUs(msg.ChannelID, wireMsg.PeerPub) {
return fn.None[ProtocolEvent]() return fn.None[ProtocolEvent]()
} }
@@ -55,7 +64,7 @@ func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] {
}) })
case *lnwire.ClosingComplete: case *lnwire.ClosingComplete:
if !r.isExpectedChanID(msg.ChannelID) { if !r.isForUs(msg.ChannelID, wireMsg.PeerPub) {
return fn.None[ProtocolEvent]() return fn.None[ProtocolEvent]()
} }
@@ -64,7 +73,7 @@ func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] {
}) })
case *lnwire.ClosingSig: case *lnwire.ClosingSig:
if !r.isExpectedChanID(msg.ChannelID) { if !r.isForUs(msg.ChannelID, wireMsg.PeerPub) {
return fn.None[ProtocolEvent]() return fn.None[ProtocolEvent]()
} }

View File

@@ -687,7 +687,7 @@ func newRbfCloserTestHarness(t *testing.T,
peerPub := randPubKey(t) peerPub := randPubKey(t)
msgMapper := NewRbfMsgMapper(uint32(startingHeight), chanID) msgMapper := NewRbfMsgMapper(uint32(startingHeight), chanID, *peerPub)
initialState := cfg.initialState.UnwrapOr(&ChannelActive{}) initialState := cfg.initialState.UnwrapOr(&ChannelActive{})

View File

@@ -2,7 +2,7 @@ package protofsm
import ( import (
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/msgmux"
) )
// MsgMapper is used to map incoming wire messages into a FSM event. This is // MsgMapper is used to map incoming wire messages into a FSM event. This is
@@ -11,5 +11,5 @@ import (
type MsgMapper[Event any] interface { type MsgMapper[Event any] interface {
// MapMsg maps a wire message into a FSM event. If the message is not // MapMsg maps a wire message into a FSM event. If the message is not
// mappable, then an None is returned. // mappable, then an None is returned.
MapMsg(msg lnwire.Message) fn.Option[Event] MapMsg(msg msgmux.PeerMsg) fn.Option[Event]
} }

View File

@@ -406,7 +406,7 @@ type dummyMsgMapper struct {
mock.Mock mock.Mock
} }
func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] { func (d *dummyMsgMapper) MapMsg(wireMsg msgmux.PeerMsg) fn.Option[dummyEvents] {
args := d.Called(wireMsg) args := d.Called(wireMsg)
//nolint:forcetypeassert //nolint:forcetypeassert