routing: update to use lnwire.ChannelUpdate interface

This commit is contained in:
Elle Mouton 2023-11-07 10:20:51 +02:00
parent 0b964d8e93
commit 43de450156
No known key found for this signature in database
GPG Key ID: D7D916376026F177
8 changed files with 94 additions and 65 deletions

@ -256,3 +256,36 @@ func (c *ChannelEdgePolicy2) GetToNode() [33]byte {
// A compile-time check to ensure that ChannelEdgePolicy2 implements the
// ChannelEdgePolicy interface.
var _ ChannelEdgePolicy = (*ChannelEdgePolicy2)(nil)
// EdgePolicyFromUpdate converts the given lnwire.ChannelUpdate into the
// corresponding ChannelEdgePolicy type.
func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) (
ChannelEdgePolicy, error) {
switch upd := update.(type) {
case *lnwire.ChannelUpdate1:
//nolint:lll
return &ChannelEdgePolicy1{
SigBytes: upd.Signature.ToSignatureBytes(),
ChannelID: upd.ShortChannelID.ToUint64(),
LastUpdate: time.Unix(int64(upd.Timestamp), 0),
MessageFlags: upd.MessageFlags,
ChannelFlags: upd.ChannelFlags,
TimeLockDelta: upd.TimeLockDelta,
MinHTLC: upd.HtlcMinimumMsat,
MaxHTLC: upd.HtlcMaximumMsat,
FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee),
FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate),
ExtraOpaqueData: upd.ExtraOpaqueData,
}, nil
case *lnwire.ChannelUpdate2:
return &ChannelEdgePolicy2{
ChannelUpdate2: *upd,
}, nil
default:
return nil, fmt.Errorf("unhandled implementation of "+
"lnwire.ChannelUpdate: %T", update)
}
}

@ -93,7 +93,7 @@ type mockGraphSource struct {
mu sync.Mutex
nodes []channeldb.LightningNode
infos map[uint64]models.ChannelEdgeInfo
edges map[uint64][]models.ChannelEdgePolicy1
edges map[uint64][]models.ChannelEdgePolicy
zombies map[uint64][][33]byte
chansToReject map[uint64]struct{}
addEdgeErrCode fn.Option[graph.ErrorCode]
@ -103,7 +103,7 @@ func newMockRouter(height uint32) *mockGraphSource {
return &mockGraphSource{
bestHeight: height,
infos: make(map[uint64]models.ChannelEdgeInfo),
edges: make(map[uint64][]models.ChannelEdgePolicy1),
edges: make(map[uint64][]models.ChannelEdgePolicy),
zombies: make(map[uint64][][33]byte),
chansToReject: make(map[uint64]struct{}),
}
@ -161,20 +161,22 @@ func (r *mockGraphSource) queueValidationFail(chanID uint64) {
r.chansToReject[chanID] = struct{}{}
}
func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy1,
func (r *mockGraphSource) UpdateEdge(edge models.ChannelEdgePolicy,
_ ...batch.SchedulerOption) error {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.edges[edge.ChannelID]) == 0 {
r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy1, 2)
chanID := edge.SCID().ToUint64()
if len(r.edges[chanID]) == 0 {
r.edges[chanID] = make([]models.ChannelEdgePolicy, 2)
}
if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
r.edges[edge.ChannelID][0] = *edge
if edge.IsNode1() {
r.edges[chanID][0] = edge
} else {
r.edges[edge.ChannelID][1] = *edge
r.edges[chanID][1] = edge
}
return nil
@ -218,7 +220,6 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,
r.mu.Lock()
defer r.mu.Unlock()
chans := make(map[uint64]channeldb.ChannelEdge)
for _, info := range r.infos {
info := info
@ -230,9 +231,9 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,
for _, edges := range r.edges {
edges := edges
edge := chans[edges[0].ChannelID]
edge.Policy1 = &edges[0]
chans[edges[0].ChannelID] = edge
edge := chans[edges[0].SCID().ToUint64()]
edge.Policy1 = edges[0]
chans[edges[0].SCID().ToUint64()] = edge
}
for _, channel := range chans {
@ -240,7 +241,6 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,
return err
}
}
return nil
}
@ -271,14 +271,14 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
return chanInfo, nil, nil, nil
}
var edge1 *models.ChannelEdgePolicy1
var edge1 models.ChannelEdgePolicy
if !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}) {
edge1 = &edges[0]
edge1 = edges[0]
}
var edge2 *models.ChannelEdgePolicy1
var edge2 models.ChannelEdgePolicy
if !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}) {
edge2 = &edges[1]
edge2 = edges[1]
}
return chanInfo, edge1, edge2, nil
@ -379,15 +379,21 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
}
switch {
case flags&lnwire.ChanUpdateDirection == 0 &&
!reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}):
case flags&lnwire.ChanUpdateDirection == 0 && edges[0] != nil:
switch edge := edges[0].(type) {
case *models.ChannelEdgePolicy1:
return !timestamp.After(edge.LastUpdate)
default:
panic(fmt.Sprintf("unhandled: %T", edges[0]))
}
return !timestamp.After(edges[0].LastUpdate)
case flags&lnwire.ChanUpdateDirection == 1 &&
!reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}):
return !timestamp.After(edges[1].LastUpdate)
case flags&lnwire.ChanUpdateDirection == 1 && edges[1] != nil:
switch edge := edges[1].(type) {
case *models.ChannelEdgePolicy1:
return !timestamp.After(edge.LastUpdate)
default:
panic(fmt.Sprintf("unhandled: %T", edges[1]))
}
default:
return false

@ -1512,49 +1512,35 @@ type routingMsg struct {
// ApplyChannelUpdate validates a channel update and if valid, applies it to the
// database. It returns a bool indicating whether the updates were successful.
func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool {
ch, _, _, err := b.GetChannelByID(msg.ShortChannelID)
func (b *Builder) ApplyChannelUpdate(msg lnwire.ChannelUpdate) bool {
ch, _, _, err := b.GetChannelByID(msg.SCID())
if err != nil {
log.Errorf("Unable to retrieve channel by id: %v", err)
return false
}
var pubKey *btcec.PublicKey
switch msg.ChannelFlags & lnwire.ChanUpdateDirection {
case 0:
if msg.IsNode1() {
pubKey, _ = ch.NodeKey1()
case 1:
} else {
pubKey, _ = ch.NodeKey2()
}
// Exit early if the pubkey cannot be decided.
if pubKey == nil {
log.Errorf("Unable to decide pubkey with ChannelFlags=%v",
msg.ChannelFlags)
return false
}
err = lnwire.ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg)
if err != nil {
log.Errorf("Unable to validate channel update: %v", err)
return false
}
err = b.UpdateEdge(&models.ChannelEdgePolicy1{
SigBytes: msg.Signature.ToSignatureBytes(),
ChannelID: msg.ShortChannelID.ToUint64(),
LastUpdate: time.Unix(int64(msg.Timestamp), 0),
MessageFlags: msg.MessageFlags,
ChannelFlags: msg.ChannelFlags,
TimeLockDelta: msg.TimeLockDelta,
MinHTLC: msg.HtlcMinimumMsat,
MaxHTLC: msg.HtlcMaximumMsat,
FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee),
FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate),
ExtraOpaqueData: msg.ExtraOpaqueData,
})
edgePolicy, err := models.EdgePolicyFromUpdate(msg)
if err != nil {
log.Errorf("Unable to convert update message to edge "+
"policy: %v", err)
return false
}
err = b.UpdateEdge(edgePolicy)
if err != nil && !IsError(err, ErrIgnored, ErrOutdated) {
log.Errorf("Unable to apply channel update: %v", err)
return false
@ -1621,7 +1607,7 @@ func (b *Builder) AddEdge(edge models.ChannelEdgeInfo,
// considered as not fully constructed.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy1,
func (b *Builder) UpdateEdge(update models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error {
rMsg := &routingMsg{

@ -39,7 +39,7 @@ type ChannelGraphSource interface {
// UpdateEdge is used to update edge information, without this message
// edge considered as not fully constructed.
UpdateEdge(policy *models.ChannelEdgePolicy1,
UpdateEdge(policy models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error
// IsStaleNode returns true if the graph source has a node announcement

@ -182,7 +182,7 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
return r, nil
}
func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate1,
func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ lnwire.ChannelUpdate,
_ *btcec.PublicKey, _ *models.CachedEdgePolicy) bool {
return false
@ -702,7 +702,7 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
return args.Get(0).(*route.Route), args.Error(1)
}
func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
func (m *mockPaymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool {
args := m.Called(msg, pubKey, policy)

@ -897,7 +897,7 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route,
// SendToRoute where there's no payment lifecycle.
if p.paySession != nil {
policy = p.paySession.GetAdditionalEdgePolicy(
errSource, update.ShortChannelID.ToUint64(),
errSource, update.SCID().ToUint64(),
)
if policy != nil {
isAdditionalEdge = true
@ -907,7 +907,8 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route,
// Apply channel update to additional edge policy.
if isAdditionalEdge {
if !p.paySession.UpdateAdditionalEdge(
update, errSource, policy) {
update, errSource, policy,
) {
log.Debugf("Invalid channel update received: node=%v",
errVertex)

@ -144,8 +144,9 @@ type PaymentSession interface {
// (private channels) and applies the update from the message. Returns
// a boolean to indicate whether the update has been applied without
// error.
UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool
UpdateAdditionalEdge(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey,
policy *models.CachedEdgePolicy) bool
// GetAdditionalEdgePolicy uses the public key and channel ID to query
// the ephemeral channel edge policy for additional edges. Returns a nil
@ -431,7 +432,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// validates the message signature and checks it's up to date, then applies the
// updates to the supplied policy. It returns a boolean to indicate whether
// there's an error when applying the updates.
func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
func (p *paymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool {
// Validate the message signature.
@ -442,10 +443,12 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
return false
}
fwdingPolicy := msg.ForwardingPolicy()
// Update channel policy for the additional edge.
policy.TimeLockDelta = msg.TimeLockDelta
policy.FeeBaseMSat = lnwire.MilliSatoshi(msg.BaseFee)
policy.FeeProportionalMillionths = lnwire.MilliSatoshi(msg.FeeRate)
policy.TimeLockDelta = fwdingPolicy.TimeLockDelta
policy.FeeBaseMSat = fwdingPolicy.BaseFee
policy.FeeProportionalMillionths = fwdingPolicy.FeeRate
log.Debugf("New private channel update applied: %v",
lnutils.SpewLogClosure(msg))

@ -284,7 +284,7 @@ type Config struct {
// ApplyChannelUpdate can be called to apply a new channel update to the
// graph that we received from a payment failure.
ApplyChannelUpdate func(msg *lnwire.ChannelUpdate1) bool
ApplyChannelUpdate func(msg lnwire.ChannelUpdate) bool
// ClosedSCIDs is used by the router to fetch closed channels.
//