diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index be5ba80fd..fb3a28458 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -527,11 +527,18 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { } case lnwire.CodeTemporaryChannelFailure: - update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate( - f.packet.incomingChanID, + update := f.htlcSwitch.failAliasUpdate( + f.packet.incomingChanID, true, ) - if err != nil { - return err + if update == nil { + // Fallback to the original, non-alias behavior. + var err error + update, err = f.htlcSwitch.cfg.FetchLastChannelUpdate( + f.packet.incomingChanID, + ) + if err != nil { + return err + } } failureMsg = lnwire.NewTemporaryChannelFailure(update) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 6ba8966bf..be132f135 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -69,6 +69,36 @@ type dustHandler interface { getDustClosure() dustClosure } +// scidAliasHandler is an interface that the ChannelLink implements so it can +// properly handle option_scid_alias channels. +type scidAliasHandler interface { + // attachFailAliasUpdate allows the link to properly fail incoming + // HTLCs on option_scid_alias channels. + attachFailAliasUpdate(failClosure func( + sid lnwire.ShortChannelID, + incoming bool) *lnwire.ChannelUpdate) + + // getAliases fetches the link's underlying aliases. This is used by + // the Switch to determine whether to forward an HTLC and where to + // forward an HTLC. + getAliases() []lnwire.ShortChannelID + + // isZeroConf returns whether or not the underlying channel is a + // zero-conf channel. + isZeroConf() bool + + // negotiatedAliasFeature returns whether the option-scid-alias feature + // bit was negotiated. + negotiatedAliasFeature() bool + + // confirmedScid returns the confirmed SCID for a zero-conf channel. + confirmedScid() lnwire.ShortChannelID + + // zeroConfConfirmed returns whether or not the zero-conf channel has + // confirmed. + zeroConfConfirmed() bool +} + // ChannelUpdateHandler is an interface that provides methods that allow // sending lnwire.Message to the underlying link as well as querying state. type ChannelUpdateHandler interface { @@ -138,6 +168,13 @@ type ChannelLink interface { // Embed the dustHandler interface. dustHandler + // Embed the scidAliasHandler interface. + scidAliasHandler + + // IsUnadvertised returns true if the underlying channel is + // unadvertised. + IsUnadvertised() bool + // ChannelPoint returns the channel outpoint for the channel link. ChannelPoint() *wire.OutPoint @@ -165,7 +202,7 @@ type ChannelLink interface { CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, amtToForward lnwire.MilliSatoshi, incomingTimeout, outgoingTimeout uint32, - heightNow uint32) *LinkError + heightNow uint32, scid lnwire.ShortChannelID) *LinkError // CheckHtlcTransit should return a nil error if the passed HTLC details // satisfy the current channel policy. Otherwise, a LinkError with a diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 86390e5b2..f2525ba86 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -294,6 +294,15 @@ type ChannelLinkConfig struct { // HtlcNotifier is an instance of a htlcNotifier which we will pipe htlc // events through. HtlcNotifier htlcNotifier + + // FailAliasUpdate is a function used to fail an HTLC for an + // option_scid_alias channel. + FailAliasUpdate func(sid lnwire.ShortChannelID, + incoming bool) *lnwire.ChannelUpdate + + // GetAliases is used by the link and switch to fetch the set of + // aliases for a given link. + GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID } // shutdownReq contains an error channel that will be used by the channelLink @@ -581,6 +590,12 @@ func (l *channelLink) markReestablished() { atomic.StoreInt32(&l.reestablished, 1) } +// IsUnadvertised returns true if the underlying channel is unadvertised. +func (l *channelLink) IsUnadvertised() bool { + state := l.channel.State() + return state.ChannelFlags&lnwire.FFAnnounceChannel == 0 +} + // sampleNetworkFee samples the current fee rate on the network to get into the // chain in a timely manner. The returned value is expressed in fee-per-kw, as // this is the native rate used when computing the fee for commitment @@ -629,14 +644,33 @@ func shouldAdjustCommitFee(netFee, chanFee, } } -// createFailureWithUpdate retrieves this link's last channel update message and -// passes it into the callback. It expects a fully populated failure message. -func (l *channelLink) createFailureWithUpdate( - cb func(update *lnwire.ChannelUpdate) lnwire.FailureMessage) lnwire.FailureMessage { +// failCb is used to cut down on the argument verbosity. +type failCb func(update *lnwire.ChannelUpdate) lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) - if err != nil { - return &lnwire.FailTemporaryNodeFailure{} +// createFailureWithUpdate creates a ChannelUpdate when failing an incoming or +// outgoing HTLC. It may return a FailureMessage that references a channel's +// alias. If the channel does not have an alias, then the regular channel +// update from disk will be returned. +func (l *channelLink) createFailureWithUpdate(incoming bool, + outgoingScid lnwire.ShortChannelID, cb failCb) lnwire.FailureMessage { + + // Determine which SCID to use in case we need to use aliases in the + // ChannelUpdate. + scid := outgoingScid + if incoming { + scid = l.ShortChanID() + } + + // Try using the FailAliasUpdate function. If it returns nil, fallback + // to the non-alias behavior. + update := l.cfg.FailAliasUpdate(scid, incoming) + if update == nil { + // Fallback to the non-alias behavior. + var err error + update, err = l.cfg.FetchLastChannelUpdate(l.ShortChanID()) + if err != nil { + return &lnwire.FailTemporaryNodeFailure{} + } } return cb(update) @@ -697,6 +731,28 @@ func (l *channelLink) syncChanStates() error { fundingLockedMsg := lnwire.NewFundingLocked( l.ChanID(), nextRevocation, ) + + // For channels that negotiated the option-scid-alias + // feature bit, ensure that we send over the alias in + // the funding_locked message. We'll send the first + // alias we find for the channel since it does not + // matter which alias we send. We'll error out if no + // aliases are found. + if l.negotiatedAliasFeature() { + aliases := l.getAliases() + if len(aliases) == 0 { + // This shouldn't happen since we + // always add at least one alias before + // the channel reaches the link. + return fmt.Errorf("no aliases found") + } + + // getAliases returns a copy of the alias slice + // so it is ok to use a pointer to the first + // entry. + fundingLockedMsg.AliasScid = &aliases[0] + } + err = l.cfg.Peer.SendMessage(false, fundingLockedMsg) if err != nil { return fmt.Errorf("unable to re-send "+ @@ -2250,29 +2306,7 @@ func (l *channelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { return hop.Source, err } - sid := l.channel.ShortChanID() - - l.log.Infof("updating to short_chan_id=%v for chan_id=%v", sid, chanID) - - l.Lock() - l.shortChanID = sid - l.Unlock() - - go func() { - err := l.cfg.UpdateContractSignals(&contractcourt.ContractSignals{ - ShortChanID: sid, - }) - if err != nil { - l.log.Errorf("unable to update signals") - } - }() - - // Now that the short channel ID has been properly updated, we can begin - // garbage collecting any forwarding packages we create. - l.wg.Add(1) - go l.fwdPkgGarbager() - - return sid, nil + return hop.Source, nil } // ChanID returns the channel ID for the channel link. The channel ID is a more @@ -2362,6 +2396,58 @@ func dustHelper(chantype channeldb.ChannelType, localDustLimit, return isDust } +// zeroConfConfirmed returns whether or not the zero-conf channel has +// confirmed on-chain. +// +// Part of the scidAliasHandler interface. +func (l *channelLink) zeroConfConfirmed() bool { + return l.channel.State().ZeroConfConfirmed() +} + +// confirmedScid returns the confirmed SCID for a zero-conf channel. This +// should not be called for non-zero-conf channels. +// +// Part of the scidAliasHandler interface. +func (l *channelLink) confirmedScid() lnwire.ShortChannelID { + return l.channel.State().ZeroConfRealScid() +} + +// isZeroConf returns whether or not the underlying channel is a zero-conf +// channel. +// +// Part of the scidAliasHandler interface. +func (l *channelLink) isZeroConf() bool { + return l.channel.State().IsZeroConf() +} + +// negotiatedAliasFeature returns whether or not the underlying channel has +// negotiated the option-scid-alias feature bit. This will be true for both +// option-scid-alias and zero-conf channel-types. It will also be true for +// channels with the feature bit but without the above channel-types. +// +// Part of the scidAliasFeature interface. +func (l *channelLink) negotiatedAliasFeature() bool { + return l.channel.State().NegotiatedAliasFeature() +} + +// getAliases returns the set of aliases for the underlying channel. +// +// Part of the scidAliasHandler interface. +func (l *channelLink) getAliases() []lnwire.ShortChannelID { + return l.cfg.GetAliases(l.ShortChanID()) +} + +// attachFailAliasUpdate sets the link's FailAliasUpdate function. +// +// Part of the scidAliasHandler interface. +func (l *channelLink) attachFailAliasUpdate(closure func( + sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate) { + + l.Lock() + l.cfg.FailAliasUpdate = closure + l.Unlock() +} + // AttachMailBox updates the current mailbox used by this link, and hooks up // the mailbox's message and packet outboxes to the link's upstream and // downstream chans, respectively. @@ -2405,7 +2491,7 @@ func (l *channelLink) UpdateForwardingPolicy(newPolicy ForwardingPolicy) { func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, incomingTimeout, outgoingTimeout uint32, - heightNow uint32) *LinkError { + heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -2414,6 +2500,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // First check whether the outgoing htlc satisfies the channel policy. err := l.canSendHtlc( policy, payHash, amtToForward, outgoingTimeout, heightNow, + originalScid, ) if err != nil { return err @@ -2437,13 +2524,10 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewFeeInsufficient( - amtToForward, *upd, - ) - }, - ) + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewFeeInsufficient(amtToForward, *upd) + } + failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) } @@ -2459,13 +2543,12 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Grab the latest routing policy so the sending node is up to // date with our current policy. - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewIncorrectCltvExpiry( - incomingTimeout, *upd, - ) - }, - ) + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewIncorrectCltvExpiry( + incomingTimeout, *upd, + ) + } + failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) } @@ -2485,8 +2568,11 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, policy := l.cfg.FwrdingPolicy l.RUnlock() + // We pass in hop.Source here as this is only used in the Switch when + // trying to send over a local link. This causes the fallback mechanism + // to occur. return l.canSendHtlc( - policy, payHash, amt, timeout, heightNow, + policy, payHash, amt, timeout, heightNow, hop.Source, ) } @@ -2494,7 +2580,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // the channel's amount and time lock constraints. func (l *channelLink) canSendHtlc(policy ForwardingPolicy, payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { // As our first sanity check, we'll ensure that the passed HTLC isn't // too small for the next hop. If so, then we'll cancel the HTLC @@ -2506,13 +2592,10 @@ func (l *channelLink) canSendHtlc(policy ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewAmountBelowMinimum( - amt, *upd, - ) - }, - ) + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewAmountBelowMinimum(amt, *upd) + } + failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) } @@ -2524,11 +2607,10 @@ func (l *channelLink) canSendHtlc(policy ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up-to-date data. - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewTemporaryChannelFailure(upd) - }, - ) + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure(upd) + } + failure := l.createFailureWithUpdate(false, originalScid, cb) return NewDetailedLinkError(failure, OutgoingFailureHTLCExceedsMax) } @@ -2539,11 +2621,11 @@ func (l *channelLink) canSendHtlc(policy ForwardingPolicy, l.log.Warnf("htlc(%x) has an expiry that's too soon: "+ "outgoing_expiry=%v, best_height=%v", payHash[:], timeout, heightNow) - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewExpiryTooSoon(*upd) - }, - ) + + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewExpiryTooSoon(*upd) + } + failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) } @@ -2560,11 +2642,10 @@ func (l *channelLink) canSendHtlc(policy ForwardingPolicy, if amt > l.Bandwidth() { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewTemporaryChannelFailure(upd) - }, - ) + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure(upd) + } + failure := l.createFailureWithUpdate(false, originalScid, cb) return NewDetailedLinkError( failure, OutgoingFailureInsufficientBalance, ) @@ -3009,12 +3090,12 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, l.log.Errorf("unable to encode the "+ "remaining route %v", err) + cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure(upd) + } + failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewTemporaryChannelFailure( - upd, - ) - }, + true, hop.Source, cb, ) l.sendHTLCError( diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 478ab29a5..4708325eb 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1874,6 +1874,12 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( return nil } + getAliases := func( + base lnwire.ShortChannelID) []lnwire.ShortChannelID { + + return nil + } + // Instantiate with a long interval, so that we can precisely control // the firing via force feeding. bticker := ticker.NewForce(time.Hour) @@ -1917,6 +1923,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( NotifyActiveChannel: func(wire.OutPoint) {}, NotifyInactiveChannel: func(wire.OutPoint) {}, HtlcNotifier: aliceSwitch.cfg.HtlcNotifier, + GetAliases: getAliases, } aliceLink := NewChannelLink(aliceCfg, aliceLc.channel) @@ -4325,6 +4332,12 @@ func (h *persistentLinkHarness) restartLink( return nil } + getAliases := func( + base lnwire.ShortChannelID) []lnwire.ShortChannelID { + + return nil + } + // Instantiate with a long interval, so that we can precisely control // the firing via force feeding. bticker := ticker.NewForce(time.Hour) @@ -4371,6 +4384,7 @@ func (h *persistentLinkHarness) restartLink( NotifyInactiveChannel: func(wire.OutPoint) {}, HtlcNotifier: aliceSwitch.cfg.HtlcNotifier, SyncStates: syncStates, + GetAliases: getAliases, } aliceLink := NewChannelLink(aliceCfg, aliceChannel) @@ -5571,6 +5585,12 @@ func TestCheckHtlcForward(t *testing.T) { return &lnwire.ChannelUpdate{}, nil } + failAliasUpdate := func(sid lnwire.ShortChannelID, + incoming bool) *lnwire.ChannelUpdate { + + return nil + } + testChannel, _, fCleanUp, err := createTestChannel( alicePrivKey, bobPrivKey, 100000, 100000, 1000, 1000, lnwire.ShortChannelID{}, @@ -5596,11 +5616,13 @@ func TestCheckHtlcForward(t *testing.T) { channel: testChannel.channel, } + link.attachFailAliasUpdate(failAliasUpdate) + var hash [32]byte t.Run("satisfied", func(t *testing.T) { result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, 0) + 200, 150, 0, lnwire.ShortChannelID{}) if result != nil { t.Fatalf("expected policy to be satisfied") } @@ -5608,7 +5630,7 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("below minhtlc", func(t *testing.T) { result := link.CheckHtlcForward(hash, 100, 50, - 200, 150, 0) + 200, 150, 0, lnwire.ShortChannelID{}) if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok { t.Fatalf("expected FailAmountBelowMinimum failure code") } @@ -5616,7 +5638,7 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("above maxhtlc", func(t *testing.T) { result := link.CheckHtlcForward(hash, 1500, 1200, - 200, 150, 0) + 200, 150, 0, lnwire.ShortChannelID{}) if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok { t.Fatalf("expected FailTemporaryChannelFailure failure code") } @@ -5624,7 +5646,7 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("insufficient fee", func(t *testing.T) { result := link.CheckHtlcForward(hash, 1005, 1000, - 200, 150, 0) + 200, 150, 0, lnwire.ShortChannelID{}) if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected FailFeeInsufficient failure code") } @@ -5632,7 +5654,7 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("expiry too soon", func(t *testing.T) { result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, 190) + 200, 150, 190, lnwire.ShortChannelID{}) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok { t.Fatalf("expected FailExpiryTooSoon failure code") } @@ -5640,7 +5662,7 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("incorrect cltv expiry", func(t *testing.T) { result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 190, 0) + 200, 190, 0, lnwire.ShortChannelID{}) if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok { t.Fatalf("expected FailIncorrectCltvExpiry failure code") } @@ -5650,7 +5672,7 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("cltv expiry too far in the future", func(t *testing.T) { // Check that expiry isn't too far in the future. result := link.CheckHtlcForward(hash, 1500, 1000, - 10200, 10100, 0) + 10200, 10100, 0, lnwire.ShortChannelID{}) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok { t.Fatalf("expected FailExpiryTooFar failure code") } diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index a8e6ed31e..de85f7818 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -91,10 +91,6 @@ type mailBoxConfig struct { // belongs to. shortChanID lnwire.ShortChannelID - // fetchUpdate retrieves the most recent channel update for the channel - // this mailbox belongs to. - fetchUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) - // forwardPackets send a varidic number of htlcPackets to the switch to // be routed. A quit channel should be provided so that the call can // properly exit during shutdown. @@ -107,6 +103,11 @@ type mailBoxConfig struct { // have not been yet been delivered. The computed deadline will expiry // this long after the Adds are added via AddPacket. expiry time.Duration + + // failMailboxUpdate is used to fail an expired HTLC and use the + // correct SCID if the underlying channel uses aliases. + failMailboxUpdate func(outScid, + mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage } // memoryMailBox is an implementation of the MailBox struct backed by purely @@ -710,13 +711,9 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { // Create a temporary channel failure which we will send back to our // peer if this is a forward, or report to the user if the failed // payment was locally initiated. - var failure lnwire.FailureMessage - update, err := m.cfg.fetchUpdate(m.cfg.shortChanID) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure(update) - } + failure := m.cfg.failMailboxUpdate( + pkt.originalOutgoingChanID, m.cfg.shortChanID, + ) // If the payment was locally initiated (which is indicated by a nil // obfuscator), we do not need to encrypt it back to the sender. @@ -817,10 +814,6 @@ type mailOrchConfig struct { // properly exit during shutdown. forwardPackets func(chan struct{}, ...*htlcPacket) error - // fetchUpdate retrieves the most recent channel update for the channel - // this mailbox belongs to. - fetchUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) - // clock is a time source for the generated mailboxes. clock clock.Clock @@ -828,6 +821,11 @@ type mailOrchConfig struct { // have not been yet been delivered. The computed deadline will expiry // this long after the Adds are added to a mailbox via AddPacket. expiry time.Duration + + // failMailboxUpdate is used to fail an expired HTLC and use the + // correct SCID if the underlying channel uses aliases. + failMailboxUpdate func(outScid, + mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage } // newMailOrchestrator initializes a fresh mailOrchestrator. @@ -881,11 +879,11 @@ func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox( mailbox, ok := mo.mailboxes[chanID] if !ok { mailbox = newMemoryMailBox(&mailBoxConfig{ - shortChanID: shortChanID, - fetchUpdate: mo.cfg.fetchUpdate, - forwardPackets: mo.cfg.forwardPackets, - clock: mo.cfg.clock, - expiry: mo.cfg.expiry, + shortChanID: shortChanID, + forwardPackets: mo.cfg.forwardPackets, + clock: mo.cfg.clock, + expiry: mo.cfg.expiry, + failMailboxUpdate: mo.cfg.failMailboxUpdate, }) mailbox.Start() mo.mailboxes[chanID] = mailbox diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index a55dac2fc..4c4ceee56 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -201,17 +201,18 @@ func newMailboxContext(t *testing.T, startTime time.Time, clock: clock.NewTestClock(startTime), forwards: make(chan *htlcPacket, 1), } - ctx.mailbox = newMemoryMailBox(&mailBoxConfig{ - fetchUpdate: func(sid lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { - return &lnwire.ChannelUpdate{ - ShortChannelID: sid, - }, nil - }, - forwardPackets: ctx.forward, - clock: ctx.clock, - expiry: expiry, + failMailboxUpdate := func(outScid, + mboxScid lnwire.ShortChannelID) lnwire.FailureMessage { + + return &lnwire.FailTemporaryNodeFailure{} + } + + ctx.mailbox = newMemoryMailBox(&mailBoxConfig{ + failMailboxUpdate: failMailboxUpdate, + forwardPackets: ctx.forward, + clock: ctx.clock, + expiry: expiry, }) ctx.mailbox.Start() @@ -660,15 +661,15 @@ func testMailBoxDust(t *testing.T, chantype channeldb.ChannelType) { func TestMailOrchestrator(t *testing.T) { t.Parallel() + failMailboxUpdate := func(outScid, + mboxScid lnwire.ShortChannelID) lnwire.FailureMessage { + + return &lnwire.FailTemporaryNodeFailure{} + } + // First, we'll create a new instance of our orchestrator. mo := newMailOrchestrator(&mailOrchConfig{ - fetchUpdate: func(sid lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { - - return &lnwire.ChannelUpdate{ - ShortChannelID: sid, - }, nil - }, + failMailboxUpdate: failMailboxUpdate, forwardPackets: func(_ chan struct{}, pkts ...*htlcPacket) error { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 5acfeeb81..b209e2c51 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -15,6 +15,7 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" @@ -33,6 +34,10 @@ import ( "github.com/lightningnetwork/lnd/ticker" ) +func isAlias(scid lnwire.ShortChannelID) bool { + return scid.BlockHeight >= 16_000_000 && scid.BlockHeight < 16_250_000 +} + type mockPreimageCache struct { sync.Mutex preimageMap map[lntypes.Hash]lntypes.Preimage @@ -180,6 +185,12 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) } } + signAliasUpdate := func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + error) { + + return testSig, nil + } + cfg := Config{ DB: db, FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, @@ -188,21 +199,27 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) FwdingLog: &mockForwardingLog{ events: make(map[time.Time]channeldb.ForwardingEvent), }, - FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { - return &lnwire.ChannelUpdate{}, nil + FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) ( + *lnwire.ChannelUpdate, error) { + + return &lnwire.ChannelUpdate{ + ShortChannelID: scid, + }, nil }, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, - FwdEventTicker: ticker.NewForce(DefaultFwdEventInterval), - LogEventTicker: ticker.NewForce(DefaultLogInterval), - AckEventTicker: ticker.NewForce(DefaultAckInterval), - HtlcNotifier: &mockHTLCNotifier{}, - Clock: clock.NewDefaultClock(), - HTLCExpiry: time.Hour, - DustThreshold: DefaultDustThreshold, + FwdEventTicker: ticker.NewForce(DefaultFwdEventInterval), + LogEventTicker: ticker.NewForce(DefaultLogInterval), + AckEventTicker: ticker.NewForce(DefaultAckInterval), + HtlcNotifier: &mockHTLCNotifier{}, + Clock: clock.NewDefaultClock(), + HTLCExpiry: time.Hour, + DustThreshold: DefaultDustThreshold, + SignAliasUpdate: signAliasUpdate, + IsAlias: isAlias, } return New(cfg, startingHeight) @@ -658,6 +675,11 @@ type mockChannelLink struct { shortChanID lnwire.ShortChannelID + // Only used for zero-conf channels. + realScid lnwire.ShortChannelID + + aliases []lnwire.ShortChannelID + chanID lnwire.ChannelID peer lnpeer.Peer @@ -668,11 +690,22 @@ type mockChannelLink struct { eligible bool + unadvertised bool + + zeroConf bool + + optionFeature bool + htlcID uint64 checkHtlcTransitResult *LinkError checkHtlcForwardResult *LinkError + + failAliasUpdate func(sid lnwire.ShortChannelID, + incoming bool) *lnwire.ChannelUpdate + + confirmedZC bool } // completeCircuit is a helper method for adding the finalized payment circuit @@ -712,16 +745,39 @@ func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error { } func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, - shortChanID lnwire.ShortChannelID, peer lnpeer.Peer, eligible bool, + shortChanID, realScid lnwire.ShortChannelID, peer lnpeer.Peer, + eligible, unadvertised, zeroConf, optionFeature bool, ) *mockChannelLink { - return &mockChannelLink{ - htlcSwitch: htlcSwitch, - chanID: chanID, - shortChanID: shortChanID, - peer: peer, - eligible: eligible, + aliases := make([]lnwire.ShortChannelID, 0) + var realConfirmed bool + + if zeroConf { + aliases = append(aliases, shortChanID) } + + if realScid != hop.Source { + realConfirmed = true + } + + return &mockChannelLink{ + htlcSwitch: htlcSwitch, + chanID: chanID, + shortChanID: shortChanID, + realScid: realScid, + peer: peer, + eligible: eligible, + unadvertised: unadvertised, + zeroConf: zeroConf, + optionFeature: optionFeature, + aliases: aliases, + confirmedZC: realConfirmed, + } +} + +// addAlias is not part of any interface method. +func (f *mockChannelLink) addAlias(alias lnwire.ShortChannelID) { + f.aliases = append(f.aliases, alias) } func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error { @@ -750,7 +806,8 @@ func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) { func (f *mockChannelLink) UpdateForwardingPolicy(_ ForwardingPolicy) { } func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, - lnwire.MilliSatoshi, uint32, uint32, uint32) *LinkError { + lnwire.MilliSatoshi, uint32, uint32, uint32, + lnwire.ShortChannelID) *LinkError { return f.checkHtlcForwardResult } @@ -772,6 +829,32 @@ func (f *mockChannelLink) AttachMailBox(mailBox MailBox) { mailBox.SetDustClosure(f.getDustClosure()) } +func (f *mockChannelLink) attachFailAliasUpdate(closure func( + sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate) { + + f.failAliasUpdate = closure +} + +func (f *mockChannelLink) getAliases() []lnwire.ShortChannelID { + return f.aliases +} + +func (f *mockChannelLink) isZeroConf() bool { + return f.zeroConf +} + +func (f *mockChannelLink) negotiatedAliasFeature() bool { + return f.optionFeature +} + +func (f *mockChannelLink) confirmedScid() lnwire.ShortChannelID { + return f.realScid +} + +func (f *mockChannelLink) zeroConfConfirmed() bool { + return f.confirmedZC +} + func (f *mockChannelLink) Start() error { f.mailBox.ResetMessages() f.mailBox.ResetPackets() @@ -788,6 +871,7 @@ func (f *mockChannelLink) EligibleToForward() bool { return func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } func (f *mockChannelLink) ShutdownIfChannelClean() error { return nil } func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } +func (f *mockChannelLink) IsUnadvertised() bool { return f.unadvertised } func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { f.eligible = true return f.shortChanID, nil diff --git a/htlcswitch/packet.go b/htlcswitch/packet.go index f971ad90b..ddd524d73 100644 --- a/htlcswitch/packet.go +++ b/htlcswitch/packet.go @@ -96,6 +96,13 @@ type htlcPacket struct { // customRecords are user-defined records in the custom type range that // were included in the payload. customRecords record.CustomSet + + // originalOutgoingChanID is used when sending back failure messages. + // It is only used for forwarded Adds on option_scid_alias channels. + // This is to avoid possible confusion if a payer uses the public SCID + // but receives a channel_update with the alias SCID. Instead, the + // payer should receive a channel_update with the public SCID. + originalOutgoingChanID lnwire.ShortChannelID } // inKey returns the circuit key used to identify the incoming htlc. diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 72d6f5968..4050d700f 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" @@ -200,6 +201,16 @@ type Config struct { // DustThreshold is the threshold in milli-satoshis after which we'll // fail incoming or outgoing dust payments for a particular channel. DustThreshold lnwire.MilliSatoshi + + // SignAliasUpdate is used when sending FailureMessages backwards for + // option_scid_alias channels. This avoids a potential privacy leak by + // replacing the public, confirmed SCID with the alias in the + // ChannelUpdate. + SignAliasUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + error) + + // IsAlias returns whether or not a given SCID is an alias. + IsAlias func(scid lnwire.ShortChannelID) bool } // Switch is the central messaging bus for all incoming/outgoing HTLCs. @@ -247,8 +258,7 @@ type Switch struct { indexMtx sync.RWMutex // pendingLinkIndex holds links that have not had their final, live - // short_chan_id assigned. These links can be transitioned into the - // primary linkIndex by using UpdateShortChanID to load their live id. + // short_chan_id assigned. pendingLinkIndex map[lnwire.ChannelID]ChannelLink // links is a map of channel id and channel link which manages @@ -311,6 +321,21 @@ type Switch struct { // contractcourt. This is used so the Switch can properly forward them, // even on restarts. resMsgStore *resolutionStore + + // aliasToReal is a map used for option-scid-alias feature-bit links. + // The alias SCID is the key and the real, confirmed SCID is the value. + // If the channel is unconfirmed, there will not be a mapping for it. + // Since channels can have multiple aliases, this map is essentially a + // N->1 mapping for a channel. This MUST be accessed with the indexMtx. + aliasToReal map[lnwire.ShortChannelID]lnwire.ShortChannelID + + // baseIndex is a map used for option-scid-alias feature-bit links. + // The value is the SCID of the link's ShortChannelID. This value may + // be an alias for zero-conf channels or a confirmed SCID for + // non-zero-conf channels with the option-scid-alias feature-bit. The + // key includes the value itself and also any other aliases. This MUST + // be accessed with the indexMtx. + baseIndex map[lnwire.ShortChannelID]lnwire.ShortChannelID } // New creates the new instance of htlc switch. @@ -345,11 +370,14 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { quit: make(chan struct{}), } + s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) + s.baseIndex = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) + s.mailOrchestrator = newMailOrchestrator(&mailOrchConfig{ - fetchUpdate: s.cfg.FetchLastChannelUpdate, - forwardPackets: s.ForwardPackets, - clock: s.cfg.Clock, - expiry: s.cfg.HTLCExpiry, + forwardPackets: s.ForwardPackets, + clock: s.cfg.Clock, + expiry: s.cfg.HTLCExpiry, + failMailboxUpdate: s.failMailboxUpdate, }) return s, nil @@ -725,14 +753,28 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{}, // failures. if len(failedPackets) > 0 { var failure lnwire.FailureMessage - update, err := s.cfg.FetchLastChannelUpdate( - failedPackets[0].incomingChanID, - ) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} + incomingID := failedPackets[0].incomingChanID + + // If the incoming channel is an option_scid_alias channel, + // then we'll need to replace the SCID in the ChannelUpdate. + update := s.failAliasUpdate(incomingID, true) + if update == nil { + // Fallback to the original non-option behavior. + update, err := s.cfg.FetchLastChannelUpdate( + incomingID, + ) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure( + update, + ) + } } else { + // This is an option_scid_alias channel. failure = lnwire.NewTemporaryChannelFailure(update) } + linkError := NewDetailedLinkError( failure, OutgoingFailureIncompleteForward, ) @@ -804,10 +846,29 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( // Try to find links by node destination. s.indexMtx.RLock() link, err := s.getLinkByShortID(pkt.outgoingChanID) - s.indexMtx.RUnlock() + defer s.indexMtx.RUnlock() if err != nil { - log.Errorf("Link %v not found", pkt.outgoingChanID) - return nil, NewLinkError(&lnwire.FailUnknownNextPeer{}) + // If the link was not found for the outgoingChanID, an outside + // subsystem may be using the confirmed SCID of a zero-conf + // channel. In this case, we'll consult the Switch maps to see + // if an alias exists and use the alias to lookup the link. + // This extra step is a consequence of not updating the Switch + // forwardingIndex when a zero-conf channel is confirmed. We + // don't need to change the outgoingChanID since the link will + // do that upon receiving the packet. + baseScid, ok := s.baseIndex[pkt.outgoingChanID] + if !ok { + log.Errorf("Link %v not found", pkt.outgoingChanID) + return nil, NewLinkError(&lnwire.FailUnknownNextPeer{}) + } + + // The base SCID was found, so we'll use that to fetch the + // link. + link, err = s.getLinkByShortID(baseScid) + if err != nil { + log.Errorf("Link %v not found", baseScid) + return nil, NewLinkError(&lnwire.FailUnknownNextPeer{}) + } } if !link.EligibleToForward() { @@ -1043,8 +1104,11 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // same incoming and outgoing channel. If our node does not // allow forwards of this nature, we fail the htlc early. This // check is in place to disallow inefficiently routed htlcs from - // locking up our balance. - linkErr := checkCircularForward( + // locking up our balance. With channels where the + // option-scid-alias feature was negotiated, we also have to be + // sure that the IDs aren't the same since one or both could be + // an alias. + linkErr := s.checkCircularForward( packet.incomingChanID, packet.outgoingChanID, s.cfg.AllowCircularRoute, htlc.PaymentHash, ) @@ -1053,7 +1117,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { } s.indexMtx.RLock() - targetLink, err := s.getLinkByShortID(packet.outgoingChanID) + targetLink, err := s.getLinkByMapping(packet) if err != nil { s.indexMtx.RUnlock() @@ -1101,6 +1165,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { htlc.PaymentHash, packet.incomingAmount, packet.amount, packet.incomingTimeout, packet.outgoingTimeout, currentHeight, + packet.originalOutgoingChanID, ) } @@ -1306,12 +1371,51 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // checkCircularForward checks whether a forward is circular (arrives and // departs on the same link) and returns a link error if the switch is // configured to disallow this behaviour. -func checkCircularForward(incoming, outgoing lnwire.ShortChannelID, +func (s *Switch) checkCircularForward(incoming, outgoing lnwire.ShortChannelID, allowCircular bool, paymentHash lntypes.Hash) *LinkError { - // If the route is not circular we do not need to perform any further - // checks. - if incoming != outgoing { + // If they are equal, we can skip the alias mapping checks. + if incoming == outgoing { + // The switch may be configured to allow circular routes, so + // just log and return nil. + if allowCircular { + log.Debugf("allowing circular route over link: %v "+ + "(payment hash: %x)", incoming, paymentHash) + return nil + } + + // Otherwise, we'll return a temporary channel failure. + return NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + OutgoingFailureCircularRoute, + ) + } + + // We'll fetch the "base" SCID from the baseIndex for the incoming and + // outgoing SCIDs. If either one does not have a base SCID, then the + // two channels are not equal since one will be a channel that does not + // need a mapping and SCID equality was checked above. If the "base" + // SCIDs are equal, then this is a circular route. Otherwise, it isn't. + s.indexMtx.RLock() + incomingBaseScid, ok := s.baseIndex[incoming] + if !ok { + // This channel does not use baseIndex, bail out. + s.indexMtx.RUnlock() + return nil + } + + outgoingBaseScid, ok := s.baseIndex[outgoing] + if !ok { + // This channel does not use baseIndex, bail out. + s.indexMtx.RUnlock() + return nil + } + s.indexMtx.RUnlock() + + // Check base SCID equality. + if incomingBaseScid != outgoingBaseScid { + // The base SCIDs are not equal so these are not the same + // channel. return nil } @@ -2170,6 +2274,9 @@ func (s *Switch) AddLink(link ChannelLink) error { mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID, shortChanID) link.AttachMailBox(mailbox) + // Attach the Switch's failAliasUpdate function to the link. + link.attachFailAliasUpdate(s.failAliasUpdate) + if err := link.Start(); err != nil { s.removeLink(chanID) return err @@ -2196,12 +2303,14 @@ func (s *Switch) AddLink(link ChannelLink) error { // addLiveLink adds a link to all associated forwarding index, this makes it a // candidate for forwarding HTLCs. func (s *Switch) addLiveLink(link ChannelLink) { + linkScid := link.ShortChanID() + // We'll add the link to the linkIndex which lets us quickly // look up a channel when we need to close or register it, and // the forwarding index which'll be used when forwarding HTLC's // in the multi-hop setting. s.linkIndex[link.ChanID()] = link - s.forwardingIndex[link.ShortChanID()] = link + s.forwardingIndex[linkScid] = link // Next we'll add the link to the interface index so we can // quickly look up all the channels for a particular node. @@ -2210,6 +2319,42 @@ func (s *Switch) addLiveLink(link ChannelLink) { s.interfaceIndex[peerPub] = make(map[lnwire.ChannelID]ChannelLink) } s.interfaceIndex[peerPub][link.ChanID()] = link + + aliases := link.getAliases() + if link.isZeroConf() { + if link.zeroConfConfirmed() { + // Since the zero-conf channel has confirmed, we can + // populate the aliasToReal mapping. + confirmedScid := link.confirmedScid() + + for _, alias := range aliases { + s.aliasToReal[alias] = confirmedScid + } + + // Add the confirmed SCID as a key in the baseIndex. + s.baseIndex[confirmedScid] = linkScid + } + + // Now we populate the baseIndex which will be used to fetch + // the link given any of the channel's alias SCIDs or the real + // SCID. The link's SCID is an alias, so we don't need to + // special-case it like the option-scid-alias feature-bit case + // further down. + for _, alias := range aliases { + s.baseIndex[alias] = linkScid + } + } else if link.negotiatedAliasFeature() { + // The link's SCID is the confirmed SCID for non-zero-conf + // option-scid-alias feature bit channels. + for _, alias := range aliases { + s.aliasToReal[alias] = linkScid + s.baseIndex[alias] = linkScid + } + + // Since the link's SCID is confirmed, it was not included in + // the baseIndex above as a key. Add it now. + s.baseIndex[linkScid] = linkScid + } } // GetLink is used to initiate the handling of the get link command. The @@ -2245,7 +2390,21 @@ func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, s.indexMtx.RLock() defer s.indexMtx.RUnlock() - return s.getLinkByShortID(chanID) + link, err := s.getLinkByShortID(chanID) + if err != nil { + // If we failed to find the link under the passed-in SCID, we + // consult the Switch's baseIndex map to see if the confirmed + // SCID was used for a zero-conf channel. + aliasID, ok := s.baseIndex[chanID] + if !ok { + return nil, err + } + + // An alias was found, use it to lookup if a link exists. + return s.getLinkByShortID(aliasID) + } + + return link, nil } // getLinkByShortID attempts to return the link which possesses the target @@ -2261,6 +2420,93 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er return link, nil } +// getLinkByMapping attempts to fetch the link via the htlcPacket's +// outgoingChanID, possibly using a mapping. If it finds the link via mapping, +// the outgoingChanID will be changed so that an error can be properly +// attributed when looping over linkErrs in handlePacketForward. +// +// * If the outgoingChanID is an alias, we'll fetch the link regardless if it's +// public or not. +// +// * If the outgoingChanID is a confirmed SCID, we'll need to do more checks. +// - If there is no entry found in baseIndex, fetch the link. This channel +// did not have the option-scid-alias feature negotiated (which includes +// zero-conf and option-scid-alias channel-types). +// - If there is an entry found, fetch the link from forwardingIndex and +// fail if this is a private link. +// +// NOTE: This MUST be called with the indexMtx read lock held. +func (s *Switch) getLinkByMapping(pkt *htlcPacket) (ChannelLink, error) { + // Determine if this ShortChannelID is an alias or a confirmed SCID. + chanID := pkt.outgoingChanID + aliasID := s.cfg.IsAlias(chanID) + + // Set the originalOutgoingChanID so the proper channel_update can be + // sent back if the option-scid-alias feature bit was negotiated. + pkt.originalOutgoingChanID = chanID + + if aliasID { + // Since outgoingChanID is an alias, we'll fetch the link via + // baseIndex. + baseScid, ok := s.baseIndex[chanID] + if !ok { + // No mapping exists, bail. + return nil, ErrChannelLinkNotFound + } + + // A mapping exists, so use baseScid to find the link in the + // forwardingIndex. + link, ok := s.forwardingIndex[baseScid] + if !ok { + // Link not found, bail. + return nil, ErrChannelLinkNotFound + } + + // Change the packet's outgoingChanID field so that errors are + // properly attributed. + pkt.outgoingChanID = baseScid + + // Return the link without checking if it's private or not. + return link, nil + } + + // The outgoingChanID is a confirmed SCID. Attempt to fetch the base + // SCID from baseIndex. + baseScid, ok := s.baseIndex[chanID] + if !ok { + // outgoingChanID is not a key in base index meaning this + // channel did not have the option-scid-alias feature bit + // negotiated. We'll fetch the link and return it. + link, ok := s.forwardingIndex[chanID] + if !ok { + // The link wasn't found, bail out. + return nil, ErrChannelLinkNotFound + } + + return link, nil + } + + // Fetch the link whose internal SCID is baseScid. + link, ok := s.forwardingIndex[baseScid] + if !ok { + // Link wasn't found, bail out. + return nil, ErrChannelLinkNotFound + } + + // If the link is unadvertised, we fail since the real SCID was used to + // forward over it and this is a channel where the option-scid-alias + // feature bit was negotiated. + if link.IsUnadvertised() { + return nil, ErrChannelLinkNotFound + } + + // The link is public so the confirmed SCID can be used to forward over + // it. We'll also replace pkt's outgoingChanID field so errors can + // properly be attributed in the calling function. + pkt.outgoingChanID = baseScid + return link, nil +} + // HasActiveLink returns true if the given channel ID has a link in the link // index AND the link is eligible to forward. func (s *Switch) HasActiveLink(chanID lnwire.ChannelID) bool { @@ -2357,50 +2603,38 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) ChannelLink { return link } -// UpdateShortChanID updates the short chan ID for an existing channel. This is -// required in the case of a re-org and re-confirmation or a channel, or in the -// case that a link was added to the switch before its short chan ID was known. +// UpdateShortChanID locates the link with the passed-in chanID and updates the +// underlying channel state. This is only used in zero-conf channels to allow +// the confirmed SCID to be updated. func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID) error { s.indexMtx.Lock() defer s.indexMtx.Unlock() - // Locate the target link in the pending link index. If no such link - // exists, then we will ignore the request. - link, ok := s.pendingLinkIndex[chanID] + // Locate the target link in the link index. If no such link exists, + // then we will ignore the request. + link, ok := s.linkIndex[chanID] if !ok { return fmt.Errorf("link %v not found", chanID) } - oldShortChanID := link.ShortChanID() - - // Try to update the link's short channel ID, returning early if this - // update failed. - shortChanID, err := link.UpdateShortChanID() + // Try to update the link's underlying channel state, returning early + // if this update failed. + _, err := link.UpdateShortChanID() if err != nil { return err } - // Reject any blank short channel ids. - if shortChanID == hop.Source { - return fmt.Errorf("refusing trivial short_chan_id for chan_id=%v"+ - "live link", chanID) + // Since the zero-conf channel is confirmed, we should populate the + // aliasToReal map and update the baseIndex. + aliases := link.getAliases() + + confirmedScid := link.confirmedScid() + + for _, alias := range aliases { + s.aliasToReal[alias] = confirmedScid } - log.Infof("Updated short_chan_id for ChannelLink(%v): old=%v, new=%v", - chanID, oldShortChanID, shortChanID) - - // Since the link was in the pending state before, we will remove it - // from the pending link index and add it to the live link index so that - // it can be available in forwarding. - delete(s.pendingLinkIndex, chanID) - s.addLiveLink(link) - - // Finally, alert the mail orchestrator to the change of short channel - // ID, and deliver any unclaimed packets to the link. - mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID, shortChanID) - s.mailOrchestrator.BindLiveShortChanID( - mailbox, chanID, shortChanID, - ) + s.baseIndex[confirmedScid] = link.ShortChanID() return nil } @@ -2569,3 +2803,205 @@ func (s *Switch) evaluateDustThreshold(link ChannelLink, // If we reached this point, this HTLC is fine to forward. return false } + +// failMailboxUpdate is passed to the mailbox orchestrator which in turn passes +// it to individual mailboxes. It allows the mailboxes to construct a +// FailureMessage when failing back HTLC's due to expiry and may include an +// alias in the ShortChannelID field. The outgoingScid is the SCID originally +// used in the onion. The mailboxScid is the SCID that the mailbox and link +// use. The mailboxScid is only used in the non-alias case, so it is always +// the confirmed SCID. +func (s *Switch) failMailboxUpdate(outgoingScid, + mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage { + + // Try to use the failAliasUpdate function in case this is a channel + // that uses aliases. If it returns nil, we'll fallback to the original + // pre-alias behavior. + update := s.failAliasUpdate(outgoingScid, false) + if update == nil { + // Execute the fallback behavior. + var err error + update, err = s.cfg.FetchLastChannelUpdate(mailboxScid) + if err != nil { + return &lnwire.FailTemporaryNodeFailure{} + } + } + + return lnwire.NewTemporaryChannelFailure(update) +} + +// failAliasUpdate prepares a ChannelUpdate for a failed incoming or outgoing +// HTLC on a channel where the option-scid-alias feature bit was negotiated. If +// the associated channel is not one of these, this function will return nil +// and the caller is expected to handle this properly. In this case, a return +// to the original non-alias behavior is expected. +func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, + incoming bool) *lnwire.ChannelUpdate { + + // This function does not defer the unlocking because of the database + // lookups for ChannelUpdate. + s.indexMtx.RLock() + + if s.cfg.IsAlias(scid) { + // The alias SCID was used. In the incoming case this means + // the channel is zero-conf as the link sets the scid. In the + // outgoing case, the sender set the scid to use and may be + // either the alias or the confirmed one, if it exists. + realScid, ok := s.aliasToReal[scid] + if !ok { + // The real, confirmed SCID does not exist yet. Find + // the "base" SCID that the link uses via the + // baseIndex. If we can't find it, return nil. This + // means the channel is zero-conf. + baseScid, ok := s.baseIndex[scid] + s.indexMtx.RUnlock() + if !ok { + return nil + } + + update, err := s.cfg.FetchLastChannelUpdate(baseScid) + if err != nil { + return nil + } + + // Replace the baseScid with the passed-in alias. + update.ShortChannelID = scid + sig, err := s.cfg.SignAliasUpdate(update) + if err != nil { + return nil + } + + update.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return nil + } + + return update + } + + s.indexMtx.RUnlock() + + // Fetch the SCID via the confirmed SCID and replace it with + // the alias. + update, err := s.cfg.FetchLastChannelUpdate(realScid) + if err != nil { + return nil + } + + // In the incoming case, we want to ensure that we don't leak + // the UTXO in case the channel is private. In the outgoing + // case, since the alias was used, we do the same thing. + update.ShortChannelID = scid + sig, err := s.cfg.SignAliasUpdate(update) + if err != nil { + return nil + } + + update.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return nil + } + + return update + } + + // If the confirmed SCID is not in baseIndex, this is not an + // option-scid-alias or zero-conf channel. + baseScid, ok := s.baseIndex[scid] + if !ok { + s.indexMtx.RUnlock() + return nil + } + + // Fetch the link so we can get an alias to use in the ShortChannelID + // of the ChannelUpdate. + link, ok := s.forwardingIndex[baseScid] + s.indexMtx.RUnlock() + if !ok { + // This should never happen, but if it does for some reason, + // fallback to the old behavior. + return nil + } + + aliases := link.getAliases() + if len(aliases) == 0 { + // This should never happen, but if it does, fallback. + return nil + } + + // Fetch the ChannelUpdate via the real, confirmed SCID. + update, err := s.cfg.FetchLastChannelUpdate(scid) + if err != nil { + return nil + } + + // The incoming case will replace the ShortChannelID in the retrieved + // ChannelUpdate with the alias to ensure no privacy leak occurs. This + // would happen if a private non-zero-conf option-scid-alias + // feature-bit channel leaked its UTXO here rather than supplying an + // alias. In the outgoing case, the confirmed SCID was actually used + // for forwarding in the onion, so no replacement is necessary as the + // sender knows the scid. + if incoming { + // We will replace and sign the update with the first alias. + // Since this happens on the incoming side, it's not actually + // possible to know what the sender used in the onion. + update.ShortChannelID = aliases[0] + sig, err := s.cfg.SignAliasUpdate(update) + if err != nil { + return nil + } + + update.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return nil + } + } + + return update +} + +// AddAliasForLink instructs the Switch to update its in-memory maps to reflect +// that a link has a new alias. +func (s *Switch) AddAliasForLink(chanID lnwire.ChannelID, + alias lnwire.ShortChannelID) error { + + // Fetch the link so that we can update the underlying channel's set of + // aliases. + s.indexMtx.RLock() + link, err := s.getLink(chanID) + s.indexMtx.RUnlock() + if err != nil { + return err + } + + // If the link is a channel where the option-scid-alias feature bit was + // not negotiated, we'll return an error. + if !link.negotiatedAliasFeature() { + return fmt.Errorf("attempted to update non-alias channel") + } + + linkScid := link.ShortChanID() + + // We'll update the maps so the Switch includes this alias in its + // forwarding decisions. + if link.isZeroConf() { + if link.zeroConfConfirmed() { + // If the channel has confirmed on-chain, we'll + // add this alias to the aliasToReal map. + confirmedScid := link.confirmedScid() + + s.aliasToReal[alias] = confirmedScid + } + + // Add this alias to the baseIndex mapping. + s.baseIndex[alias] = linkScid + } else if link.negotiatedAliasFeature() { + // The channel is confirmed, so we'll populate the aliasToReal + // and baseIndex maps. + s.aliasToReal[alias] = linkScid + s.baseIndex[alias] = linkScid + } + + return nil +} diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index d2d80ee9e..335eb1bc2 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "io/ioutil" + mrand "math/rand" + "os" "reflect" "testing" "time" @@ -26,6 +28,7 @@ import ( ) var zeroCircuit = channeldb.CircuitKey{} +var emptyScid = lnwire.ShortChannelID{} func genPreimage() ([32]byte, error) { var preimage [32]byte @@ -36,8 +39,8 @@ func genPreimage() ([32]byte, error) { } // TestSwitchAddDuplicateLink tests that the switch will reject duplicate links -// for both pending and live links. It also tests that we can successfully -// add a link after having removed it. +// for live links. It also tests that we can successfully add a link after +// having removed it. func TestSwitchAddDuplicateLink(t *testing.T) { t.Parallel() @@ -53,27 +56,16 @@ func TestSwitchAddDuplicateLink(t *testing.T) { } defer s.Stop() - chanID1, _, aliceChanID, _ := genIDs() - - pendingChanID := lnwire.ShortChannelID{} + chanID1, aliceScid := genID() aliceChannelLink := newMockChannelLink( - s, chanID1, pendingChanID, alicePeer, false, + s, chanID1, aliceScid, emptyScid, alicePeer, false, false, + false, false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } - // Alice should have a pending link, adding again should fail. - if err := s.AddLink(aliceChannelLink); err == nil { - t.Fatalf("adding duplicate link should have failed") - } - - // Update the short chan id of the channel, so that the link goes live. - aliceChannelLink.setLiveShortChanID(aliceChanID) - err = s.UpdateShortChanID(chanID1) - require.NoError(t, err, "unable to update alice short_chan_id") - // Alice should have a live link, adding again should fail. if err := s.AddLink(aliceChannelLink); err == nil { t.Fatalf("adding duplicate link should have failed") @@ -107,12 +99,11 @@ func TestSwitchHasActiveLink(t *testing.T) { } defer s.Stop() - chanID1, _, aliceChanID, _ := genIDs() - - pendingChanID := lnwire.ShortChannelID{} + chanID1, aliceScid := genID() aliceChannelLink := newMockChannelLink( - s, chanID1, pendingChanID, alicePeer, false, + s, chanID1, aliceScid, emptyScid, alicePeer, false, false, + false, false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -125,24 +116,6 @@ func TestSwitchHasActiveLink(t *testing.T) { t.Fatalf("link should not be active yet, still pending") } - // Update the short chan id of the channel, so that the link goes live. - aliceChannelLink.setLiveShortChanID(aliceChanID) - err = s.UpdateShortChanID(chanID1) - require.NoError(t, err, "unable to update alice short_chan_id") - - // UpdateShortChanID will cause the mock link to become eligible to - // forward. However, we can simulate the event where the short chan id - // is confirmed, but funding locked has yet to be received by resetting - // the mock link's eligibility to false. - aliceChannelLink.eligible = false - - // Now, even though the link has been added to the linkIndex because the - // short channel id has confirmed, we should still see HasActiveLink - // fail because EligibleToForward should return false. - if s.HasActiveLink(chanID1) { - t.Fatalf("link should not be active yet, still ineligible") - } - // Finally, simulate the link receiving funding locked by setting its // eligibility to true. aliceChannelLink.eligible = true @@ -155,7 +128,7 @@ func TestSwitchHasActiveLink(t *testing.T) { } // TestSwitchSendPending checks the inability of htlc switch to forward adds -// over pending links, and the UpdateShortChanID makes a pending link live. +// over pending links. func TestSwitchSendPending(t *testing.T) { t.Parallel() @@ -181,14 +154,16 @@ func TestSwitchSendPending(t *testing.T) { pendingChanID := lnwire.ShortChannelID{} aliceChannelLink := newMockChannelLink( - s, chanID1, pendingChanID, alicePeer, false, + s, chanID1, pendingChanID, emptyScid, alicePeer, false, false, + false, false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(bobChannelLink); err != nil { t.Fatalf("unable to add bob link: %v", err) @@ -244,30 +219,653 @@ func TestSwitchSendPending(t *testing.T) { if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") } +} - // Now, update Alice's link with her final short channel id. This should - // move the link to the live state. - aliceChannelLink.setLiveShortChanID(aliceChanID) - err = s.UpdateShortChanID(chanID1) - require.NoError(t, err, "unable to update alice short_chan_id") +// TestSwitchForwardMapping checks that the Switch properly consults its maps +// when forwarding packets. +func TestSwitchForwardMapping(t *testing.T) { + tests := []struct { + name string - // Increment the packet's HTLC index, so that it does not collide with - // the prior attempt. - packet.incomingHTLCID++ + // If this is true, then Alice's channel will be private. + alicePrivate bool - // Handle the request and checks that bob channel link received it. - if err := s.ForwardPackets(nil, packet); err != nil { - t.Fatalf("unexpected forward failure: %v", err) + // If this is true, then Alice's channel will be a zero-conf + // channel. + zeroConf bool + + // If this is true, then Alice's channel will be an + // option-scid-alias feature-bit, non-zero-conf channel. + optionScid bool + + // If this is true, then an alias will be used for forwarding. + useAlias bool + + // This is Alice's channel alias. This may not be set if this + // is not an option_scid_alias channel (feature bit). + aliceAlias lnwire.ShortChannelID + + // This is Alice's confirmed SCID. This may not be set if this + // is a zero-conf channel before confirmation. + aliceReal lnwire.ShortChannelID + + // If this is set, we expect Bob forwarding to Alice to fail. + expectErr bool + }{ + { + name: "private unconfirmed zero-conf", + alicePrivate: true, + zeroConf: true, + useAlias: true, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_002, + TxIndex: 2, + TxPosition: 2, + }, + aliceReal: lnwire.ShortChannelID{}, + expectErr: false, + }, + { + name: "private confirmed zero-conf", + alicePrivate: true, + zeroConf: true, + useAlias: true, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_003, + TxIndex: 3, + TxPosition: 3, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 300000, + TxIndex: 3, + TxPosition: 3, + }, + expectErr: false, + }, + { + name: "private confirmed zero-conf failure", + alicePrivate: true, + zeroConf: true, + useAlias: false, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_004, + TxIndex: 4, + TxPosition: 4, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 300002, + TxIndex: 4, + TxPosition: 4, + }, + expectErr: true, + }, + { + name: "public unconfirmed zero-conf", + alicePrivate: false, + zeroConf: true, + useAlias: true, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_005, + TxIndex: 5, + TxPosition: 5, + }, + aliceReal: lnwire.ShortChannelID{}, + expectErr: false, + }, + { + name: "public confirmed zero-conf w/ alias", + alicePrivate: false, + zeroConf: true, + useAlias: true, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_006, + TxIndex: 6, + TxPosition: 6, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 500000, + TxIndex: 6, + TxPosition: 6, + }, + expectErr: false, + }, + { + name: "public confirmed zero-conf w/ real", + alicePrivate: false, + zeroConf: true, + useAlias: false, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_007, + TxIndex: 7, + TxPosition: 7, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 502000, + TxIndex: 7, + TxPosition: 7, + }, + expectErr: false, + }, + { + name: "private non-option channel", + alicePrivate: true, + aliceAlias: lnwire.ShortChannelID{}, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 505000, + TxIndex: 8, + TxPosition: 8, + }, + }, + { + name: "private option channel w/ alias", + alicePrivate: true, + optionScid: true, + useAlias: true, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_015, + TxIndex: 9, + TxPosition: 9, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 506000, + TxIndex: 10, + TxPosition: 10, + }, + expectErr: false, + }, + { + name: "private option channel failure", + alicePrivate: true, + optionScid: true, + useAlias: false, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_016, + TxIndex: 16, + TxPosition: 16, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 507000, + TxIndex: 17, + TxPosition: 17, + }, + expectErr: true, + }, + { + name: "public non-option channel", + alicePrivate: false, + useAlias: false, + aliceAlias: lnwire.ShortChannelID{}, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 508000, + TxIndex: 17, + TxPosition: 17, + }, + expectErr: false, + }, + { + name: "public option channel w/ alias", + alicePrivate: false, + optionScid: true, + useAlias: true, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_018, + TxIndex: 18, + TxPosition: 18, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 509000, + TxIndex: 19, + TxPosition: 19, + }, + expectErr: false, + }, + { + name: "public option channel w/ real", + alicePrivate: false, + optionScid: true, + useAlias: false, + aliceAlias: lnwire.ShortChannelID{ + BlockHeight: 16_000_019, + TxIndex: 19, + TxPosition: 19, + }, + aliceReal: lnwire.ShortChannelID{ + BlockHeight: 510000, + TxIndex: 20, + TxPosition: 20, + }, + expectErr: false, + }, } - // Since Alice's link is now active, this packet should succeed. - select { - case <-aliceChannelLink.packets: - case <-time.After(time.Second): - t.Fatal("request was not propagated to alice") + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + testSwitchForwardMapping( + t, test.alicePrivate, test.zeroConf, + test.useAlias, test.optionScid, + test.aliceAlias, test.aliceReal, + test.expectErr, + ) + }) } } +func testSwitchForwardMapping(t *testing.T, alicePrivate, aliceZeroConf, + useAlias, optionScid bool, aliceAlias, aliceReal lnwire.ShortChannelID, + expectErr bool) { + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + s, err := initSwitchWithDB(testStartingHeight, nil) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + defer func() { _ = s.Stop() }() + + // Create the lnwire.ChannelIDs that we'll use. + chanID1, chanID2, _, _ := genIDs() + + var aliceChannelLink *mockChannelLink + + if aliceZeroConf { + aliceChannelLink = newMockChannelLink( + s, chanID1, aliceAlias, aliceReal, alicePeer, true, + alicePrivate, true, false, + ) + } else { + aliceChannelLink = newMockChannelLink( + s, chanID1, aliceReal, emptyScid, alicePeer, true, + alicePrivate, false, optionScid, + ) + + if optionScid { + aliceChannelLink.addAlias(aliceAlias) + } + } + + err = s.AddLink(aliceChannelLink) + require.NoError(t, err) + + // Bob will just have a non-option_scid_alias channel so no mapping is + // necessary. + bobScid := lnwire.ShortChannelID{ + BlockHeight: 501000, + TxIndex: 200, + TxPosition: 2, + } + + bobChannelLink := newMockChannelLink( + s, chanID2, bobScid, emptyScid, bobPeer, true, false, false, + false, + ) + err = s.AddLink(bobChannelLink) + require.NoError(t, err) + + // Generate preimage. + preimage, err := genPreimage() + require.NoError(t, err, "unable to generate preimage") + rhash := sha256.Sum256(preimage[:]) + + // Determine the outgoing SCID to use. + outgoingSCID := aliceReal + if useAlias { + outgoingSCID = aliceAlias + } + + packet := &htlcPacket{ + incomingChanID: bobScid, + incomingHTLCID: 0, + outgoingChanID: outgoingSCID, + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + err = s.ForwardPackets(nil, packet) + require.NoError(t, err) + + // If we expect a forwarding error, then assert that we receive one. + // option_scid_alias forwards may fail if forwarding would be a privacy + // leak. + if expectErr { + select { + case <-bobChannelLink.packets: + case <-time.After(time.Second * 5): + t.Fatal("expected a forwarding error") + } + + select { + case <-aliceChannelLink.packets: + t.Fatal("did not expect a packet") + case <-time.After(time.Second * 5): + } + } else { + select { + case <-bobChannelLink.packets: + t.Fatal("did not expect a forwarding error") + case <-time.After(time.Second * 5): + } + + select { + case <-aliceChannelLink.packets: + case <-time.After(time.Second * 5): + t.Fatal("expected alice to receive packet") + } + } +} + +// TestSwitchSendHTLCMapping tests that SendHTLC will properly route packets to +// zero-conf or option-scid-alias (feature-bit) channels if the confirmed SCID +// is used. It also tests that nothing breaks with the mapping change. +func TestSwitchSendHTLCMapping(t *testing.T) { + tests := []struct { + name string + + // If this is true, the channel will be zero-conf. + zeroConf bool + + // Denotes whether the channel is option-scid-alias, non + // zero-conf feature bit. + optionFeature bool + + // If this is true, then the alias will be used in the packet. + useAlias bool + + // This will be the channel alias if there is a mapping. + alias lnwire.ShortChannelID + + // This will be the confirmed SCID if the channel is confirmed. + real lnwire.ShortChannelID + }{ + { + name: "non-zero-conf real scid w/ option", + zeroConf: false, + optionFeature: true, + useAlias: false, + alias: lnwire.ShortChannelID{ + BlockHeight: 10010, + TxIndex: 10, + TxPosition: 10, + }, + real: lnwire.ShortChannelID{ + BlockHeight: 500000, + TxIndex: 50, + TxPosition: 50, + }, + }, + { + name: "non-zero-conf real scid no option", + zeroConf: false, + useAlias: false, + alias: lnwire.ShortChannelID{}, + real: lnwire.ShortChannelID{ + BlockHeight: 400000, + TxIndex: 50, + TxPosition: 50, + }, + }, + { + name: "zero-conf alias scid w/ conf", + zeroConf: true, + useAlias: true, + alias: lnwire.ShortChannelID{ + BlockHeight: 10020, + TxIndex: 20, + TxPosition: 20, + }, + real: lnwire.ShortChannelID{ + BlockHeight: 450000, + TxIndex: 50, + TxPosition: 50, + }, + }, + { + name: "zero-conf alias scid no conf", + zeroConf: true, + useAlias: true, + alias: lnwire.ShortChannelID{ + BlockHeight: 10015, + TxIndex: 25, + TxPosition: 35, + }, + real: lnwire.ShortChannelID{}, + }, + { + name: "zero-conf real scid", + zeroConf: true, + useAlias: false, + alias: lnwire.ShortChannelID{ + BlockHeight: 10035, + TxIndex: 35, + TxPosition: 35, + }, + real: lnwire.ShortChannelID{ + BlockHeight: 470000, + TxIndex: 35, + TxPosition: 45, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + testSwitchSendHtlcMapping( + t, test.zeroConf, test.useAlias, test.alias, + test.real, test.optionFeature, + ) + }) + } +} + +func testSwitchSendHtlcMapping(t *testing.T, zeroConf, useAlias bool, alias, + realScid lnwire.ShortChannelID, optionFeature bool) { + + peer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + s, err := initSwitchWithDB(testStartingHeight, nil) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + defer func() { _ = s.Stop() }() + + // Create the lnwire.ChannelID that we'll use. + chanID, _ := genID() + + var link *mockChannelLink + + if zeroConf { + link = newMockChannelLink( + s, chanID, alias, realScid, peer, true, false, true, + false, + ) + } else { + link = newMockChannelLink( + s, chanID, realScid, emptyScid, peer, true, false, + false, true, + ) + + if optionFeature { + link.addAlias(alias) + } + } + + err = s.AddLink(link) + require.NoError(t, err) + + // Generate preimage. + preimage, err := genPreimage() + require.NoError(t, err) + rhash := sha256.Sum256(preimage[:]) + + // Determine the outgoing SCID to use. + outgoingSCID := realScid + if useAlias { + outgoingSCID = alias + } + + // Send the HTLC and assert that we don't get an error. + htlc := &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + } + + err = s.SendHTLC(outgoingSCID, 0, htlc) + require.NoError(t, err) +} + +// TestSwitchUpdateScid verifies that zero-conf and non-zero-conf +// option-scid-alias (feature bit) channels will have the expected entries in +// the aliasToReal and baseIndex maps. +func TestSwitchUpdateScid(t *testing.T) { + t.Parallel() + + peer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err, "unable to create alice server") + + s, err := initSwitchWithDB(testStartingHeight, nil) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + defer func() { _ = s.Stop() }() + + // Create the IDs that we'll use. + chanID, chanID2, _, _ := genIDs() + + alias := lnwire.ShortChannelID{ + BlockHeight: 16_000_000, + TxIndex: 0, + TxPosition: 0, + } + alias2 := alias + alias2.TxPosition = 1 + + realScid := lnwire.ShortChannelID{ + BlockHeight: 500000, + TxIndex: 0, + TxPosition: 0, + } + + link := newMockChannelLink( + s, chanID, alias, emptyScid, peer, true, false, true, false, + ) + link.addAlias(alias2) + + err = s.AddLink(link) + require.NoError(t, err) + + // Assert that the zero-conf link does not have entries in the + // aliasToReal map. + s.indexMtx.RLock() + _, ok := s.aliasToReal[alias] + require.False(t, ok) + _, ok = s.aliasToReal[alias2] + require.False(t, ok) + + // Assert that both aliases point to the "base" SCID, which is actually + // just the first alias. + baseScid, ok := s.baseIndex[alias] + require.True(t, ok) + require.Equal(t, alias, baseScid) + + baseScid, ok = s.baseIndex[alias2] + require.True(t, ok) + require.Equal(t, alias, baseScid) + + s.indexMtx.RUnlock() + + // We'll set the mock link's confirmed SCID so that UpdateShortChanID + // populates aliasToReal and adds an entry to baseIndex. + link.realScid = realScid + link.confirmedZC = true + + err = s.UpdateShortChanID(chanID) + require.NoError(t, err) + + // Assert that aliasToReal is populated and there is an entry in + // baseIndex for realScid. + s.indexMtx.RLock() + realMapping, ok := s.aliasToReal[alias] + require.True(t, ok) + require.Equal(t, realScid, realMapping) + + realMapping, ok = s.aliasToReal[alias2] + require.True(t, ok) + require.Equal(t, realScid, realMapping) + + baseScid, ok = s.baseIndex[realScid] + require.True(t, ok) + require.Equal(t, alias, baseScid) + + s.indexMtx.RUnlock() + + // Now we'll perform the same checks with a non-zero-conf + // option-scid-alias channel (feature-bit). + optionReal := lnwire.ShortChannelID{ + BlockHeight: 600000, + TxIndex: 0, + TxPosition: 0, + } + optionAlias := lnwire.ShortChannelID{ + BlockHeight: 12000, + TxIndex: 0, + TxPosition: 0, + } + optionAlias2 := optionAlias + optionAlias2.TxPosition = 1 + link2 := newMockChannelLink( + s, chanID2, optionReal, emptyScid, peer, true, false, false, + true, + ) + link2.addAlias(optionAlias) + link2.addAlias(optionAlias2) + + err = s.AddLink(link2) + require.NoError(t, err) + + // Assert that the option-scid-alias link does have entries in the + // aliasToReal and baseIndex maps. + s.indexMtx.RLock() + realMapping, ok = s.aliasToReal[optionAlias] + require.True(t, ok) + require.Equal(t, optionReal, realMapping) + + realMapping, ok = s.aliasToReal[optionAlias2] + require.True(t, ok) + require.Equal(t, optionReal, realMapping) + + baseScid, ok = s.baseIndex[optionReal] + require.True(t, ok) + require.Equal(t, optionReal, baseScid) + + baseScid, ok = s.baseIndex[optionAlias] + require.True(t, ok) + require.Equal(t, optionReal, baseScid) + + baseScid, ok = s.baseIndex[optionAlias2] + require.True(t, ok) + require.Equal(t, optionReal, baseScid) + + s.indexMtx.RUnlock() +} + // TestSwitchForward checks the ability of htlc switch to forward add/settle // requests. func TestSwitchForward(t *testing.T) { @@ -276,14 +874,20 @@ func TestSwitchForward(t *testing.T) { alicePeer, err := newMockServer( t, "alice", testStartingHeight, nil, testDefaultDelta, ) - require.NoError(t, err, "unable to create alice server") + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } bobPeer, err := newMockServer( t, "bob", testStartingHeight, nil, testDefaultDelta, ) - require.NoError(t, err, "unable to create bob server") + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } s, err := initSwitchWithDB(testStartingHeight, nil) - require.NoError(t, err, "unable to init switch") + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } if err := s.Start(); err != nil { t.Fatalf("unable to start switch: %v", err) } @@ -292,10 +896,12 @@ func TestSwitchForward(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -307,7 +913,9 @@ func TestSwitchForward(t *testing.T) { // Create request which should be forwarded from Alice channel link to // bob channel link. preimage, err := genPreimage() - require.NoError(t, err, "unable to generate preimage") + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } rhash := sha256.Sum256(preimage[:]) packet := &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), @@ -381,7 +989,9 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { alicePeer, err := newMockServer( t, "alice", testStartingHeight, nil, testDefaultDelta, ) - require.NoError(t, err, "unable to create alice server") + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } bobPeer, err := newMockServer( t, "bob", testStartingHeight, nil, testDefaultDelta, ) @@ -405,10 +1015,12 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { defer s.Stop() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -496,10 +1108,12 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { defer s2.Stop() aliceChannelLink = newMockChannelLink( - s2, chanID1, aliceChanID, alicePeer, true, + s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink = newMockChannelLink( - s2, chanID2, bobChanID, bobPeer, true, + s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s2.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -590,10 +1204,12 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { defer s.Stop() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -681,10 +1297,12 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { defer s2.Stop() aliceChannelLink = newMockChannelLink( - s2, chanID1, aliceChanID, alicePeer, true, + s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink = newMockChannelLink( - s2, chanID2, bobChanID, bobPeer, true, + s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s2.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -778,10 +1396,12 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { defer s.Stop() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -861,10 +1481,12 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { defer s2.Stop() aliceChannelLink = newMockChannelLink( - s2, chanID1, aliceChanID, alicePeer, true, + s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink = newMockChannelLink( - s2, chanID2, bobChanID, bobPeer, true, + s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s2.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -929,10 +1551,12 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { defer s.Stop() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1007,10 +1631,12 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { defer s2.Stop() aliceChannelLink = newMockChannelLink( - s2, chanID1, aliceChanID, alicePeer, true, + s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink = newMockChannelLink( - s2, chanID2, bobChanID, bobPeer, true, + s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s2.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1081,10 +1707,12 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { defer s.Stop() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1158,10 +1786,12 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { defer s2.Stop() aliceChannelLink = newMockChannelLink( - s2, chanID1, aliceChanID, alicePeer, true, + s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink = newMockChannelLink( - s2, chanID2, bobChanID, bobPeer, true, + s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s2.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1243,10 +1873,12 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { defer s3.Stop() aliceChannelLink = newMockChannelLink( - s3, chanID1, aliceChanID, alicePeer, true, + s3, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink = newMockChannelLink( - s3, chanID2, bobChanID, bobPeer, true, + s3, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s3.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1326,7 +1958,8 @@ func TestCircularForwards(t *testing.T) { s.cfg.AllowCircularRoute = test.allowCircularPayment aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, + true, false, false, false, ) if err := s.AddLink(aliceChannelLink); err != nil { @@ -1378,6 +2011,10 @@ func TestCheckCircularForward(t *testing.T) { tests := []struct { name string + // aliasMapping determines whether the test should add an alias + // mapping to Switch alias maps before checkCircularForward. + aliasMapping bool + // allowCircular determines whether we should allow circular // forwards. allowCircular bool @@ -1394,6 +2031,7 @@ func TestCheckCircularForward(t *testing.T) { }{ { name: "not circular, allowed in config", + aliasMapping: false, allowCircular: true, incomingLink: lnwire.NewShortChanIDFromInt(123), outgoingLink: lnwire.NewShortChanIDFromInt(321), @@ -1401,6 +2039,7 @@ func TestCheckCircularForward(t *testing.T) { }, { name: "not circular, not allowed in config", + aliasMapping: false, allowCircular: false, incomingLink: lnwire.NewShortChanIDFromInt(123), outgoingLink: lnwire.NewShortChanIDFromInt(321), @@ -1408,6 +2047,7 @@ func TestCheckCircularForward(t *testing.T) { }, { name: "circular, allowed in config", + aliasMapping: false, allowCircular: true, incomingLink: lnwire.NewShortChanIDFromInt(123), outgoingLink: lnwire.NewShortChanIDFromInt(123), @@ -1415,6 +2055,7 @@ func TestCheckCircularForward(t *testing.T) { }, { name: "circular, not allowed in config", + aliasMapping: false, allowCircular: false, incomingLink: lnwire.NewShortChanIDFromInt(123), outgoingLink: lnwire.NewShortChanIDFromInt(123), @@ -1423,6 +2064,52 @@ func TestCheckCircularForward(t *testing.T) { OutgoingFailureCircularRoute, ), }, + { + name: "circular with map, not allowed", + aliasMapping: true, + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(1 << 60), + outgoingLink: lnwire.NewShortChanIDFromInt(1 << 55), + expectedErr: NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + OutgoingFailureCircularRoute, + ), + }, + { + name: "circular with map, not allowed 2", + aliasMapping: true, + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(1 << 55), + outgoingLink: lnwire.NewShortChanIDFromInt(1 << 60), + expectedErr: NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + OutgoingFailureCircularRoute, + ), + }, + { + name: "circular with map, allowed", + aliasMapping: true, + allowCircular: true, + incomingLink: lnwire.NewShortChanIDFromInt(1 << 60), + outgoingLink: lnwire.NewShortChanIDFromInt(1 << 55), + expectedErr: nil, + }, + { + name: "circular with map, allowed 2", + aliasMapping: true, + allowCircular: true, + incomingLink: lnwire.NewShortChanIDFromInt(1 << 55), + outgoingLink: lnwire.NewShortChanIDFromInt(1 << 61), + expectedErr: nil, + }, + { + name: "not circular, both confirmed SCID", + aliasMapping: false, + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(1 << 60), + outgoingLink: lnwire.NewShortChanIDFromInt(1 << 61), + expectedErr: nil, + }, } for _, test := range tests { @@ -1431,9 +2118,26 @@ func TestCheckCircularForward(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() + s, err := initSwitchWithDB(testStartingHeight, nil) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + defer func() { _ = s.Stop() }() + + if test.aliasMapping { + // Make the incoming and outgoing point to the + // same base SCID. + inScid := test.incomingLink + outScid := test.outgoingLink + s.indexMtx.Lock() + s.baseIndex[inScid] = outScid + s.baseIndex[outScid] = outScid + s.indexMtx.Unlock() + } + // Check for a circular forward, the hash passed can // be nil because it is only used for logging. - err := checkCircularForward( + err = s.checkCircularForward( test.incomingLink, test.outgoingLink, test.allowCircular, lntypes.Hash{}, ) @@ -1528,20 +2232,23 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T, chanID1, aliceChanID := genID() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) // We'll create a link for Bob, but mark the link as unable to forward // any new outgoing HTLC's. chanID2, bobChanID2 := genID() bobChannelLink1 := newMockChannelLink( - s, chanID2, bobChanID2, bobPeer, testCase.eligible1, + s, chanID2, bobChanID2, emptyScid, bobPeer, testCase.eligible1, + false, false, false, ) bobChannelLink1.checkHtlcForwardResult = testCase.failure1 chanID3, bobChanID3 := genID() bobChannelLink2 := newMockChannelLink( - s, chanID3, bobChanID3, bobPeer, testCase.eligible2, + s, chanID3, bobChanID3, emptyScid, bobPeer, testCase.eligible2, + false, false, false, ) bobChannelLink2.checkHtlcForwardResult = testCase.failure2 @@ -1649,7 +2356,8 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, chanID1, _, aliceChanID, _ := genIDs() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, eligible, + s, chanID1, aliceChanID, emptyScid, alicePeer, eligible, false, + false, false, ) aliceChannelLink.checkHtlcTransitResult = NewLinkError( policyResult, @@ -1703,10 +2411,12 @@ func TestSwitchCancel(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1810,10 +2520,12 @@ func TestSwitchAddSamePayment(t *testing.T) { defer s.Stop() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -1963,7 +2675,8 @@ func TestSwitchSendPayment(t *testing.T) { chanID1, _, aliceChanID, _ := genIDs() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add link: %v", err) @@ -2483,7 +3196,8 @@ func TestInvalidFailure(t *testing.T) { // Set up a mock channel link. aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add link: %v", err) @@ -3076,10 +3790,12 @@ func TestSwitchHoldForward(t *testing.T) { }() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) @@ -3706,19 +4422,22 @@ func TestSwitchMailboxDust(t *testing.T) { chanID3, carolChanID := genID() aliceLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) err = s.AddLink(aliceLink) require.NoError(t, err) bobLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) err = s.AddLink(bobLink) require.NoError(t, err) carolLink := newMockChannelLink( - s, chanID3, carolChanID, carolPeer, true, + s, chanID3, carolChanID, emptyScid, carolPeer, true, false, + false, false, ) err = s.AddLink(carolLink) require.NoError(t, err) @@ -3823,10 +4542,12 @@ func TestSwitchResolution(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, + false, false, ) bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, true, + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, ) err = s.AddLink(aliceChannelLink) require.NoError(t, err) @@ -3939,3 +4660,708 @@ func TestSwitchResolution(t *testing.T) { require.NoError(t, err) require.Equal(t, 0, len(resMsgs)) } + +// TestSwitchForwardFailAlias tests that if ForwardPackets returns a failure +// before actually forwarding, the ChannelUpdate uses the SCID from the +// incoming channel and does not leak private information like the UTXO. +func TestSwitchForwardFailAlias(t *testing.T) { + tests := []struct { + name string + + // Whether or not Alice will be a zero-conf channel or an + // option-scid-alias channel (feature-bit). + zeroConf bool + }{ + { + name: "option-scid-alias forwarding failure", + zeroConf: false, + }, + { + name: "zero-conf forwarding failure", + zeroConf: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + testSwitchForwardFailAlias(t, test.zeroConf) + }) + } +} + +func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + tempPath, err := ioutil.TempDir("", "circuitdb") + require.NoError(t, err) + + cdb, err := channeldb.Open(tempPath) + require.NoError(t, err) + + s, err := initSwitchWithDB(testStartingHeight, cdb) + require.NoError(t, err) + + err = s.Start() + require.NoError(t, err) + + // Make Alice's channel zero-conf or option-scid-alias (feature bit). + aliceAlias := lnwire.ShortChannelID{ + BlockHeight: 16_000_000, + TxIndex: 5, + TxPosition: 5, + } + + var aliceLink *mockChannelLink + if zeroConf { + aliceLink = newMockChannelLink( + s, chanID1, aliceAlias, aliceChanID, alicePeer, true, + true, true, false, + ) + } else { + aliceLink = newMockChannelLink( + s, chanID1, aliceChanID, emptyScid, alicePeer, true, + true, false, true, + ) + aliceLink.addAlias(aliceAlias) + } + err = s.AddLink(aliceLink) + require.NoError(t, err) + + bobLink := newMockChannelLink( + s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, + ) + err = s.AddLink(bobLink) + require.NoError(t, err) + + // Create a packet that will be sent from Alice to Bob via the switch. + preimage := [sha256.Size]byte{1} + rhash := sha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + // Forward the packet and check that Bob's channel link received it. + err = s.ForwardPackets(nil, ogPacket) + require.NoError(t, err) + + // Assert that the circuits are in the expected state. + require.Equal(t, 1, s.circuits.NumPending()) + require.Equal(t, 0, s.circuits.NumOpen()) + + // Pull packet from Bob's link, and do nothing with it. + select { + case <-bobLink.packets: + case <-s.quit: + t.Fatal("switch shutting down, failed to forward packet") + } + + // Now we will restart the Switch to trigger the LoadedFromDisk logic. + err = s.Stop() + require.NoError(t, err) + + err = cdb.Close() + require.NoError(t, err) + + cdb2, err := channeldb.Open(tempPath) + require.NoError(t, err) + + s2, err := initSwitchWithDB(testStartingHeight, cdb2) + require.NoError(t, err) + + err = s2.Start() + require.NoError(t, err) + + defer func() { + _ = s2.Stop() + _ = os.RemoveAll(tempPath) + }() + + var aliceLink2 *mockChannelLink + if zeroConf { + aliceLink2 = newMockChannelLink( + s2, chanID1, aliceAlias, aliceChanID, alicePeer, true, + true, true, false, + ) + } else { + aliceLink2 = newMockChannelLink( + s2, chanID1, aliceChanID, emptyScid, alicePeer, true, + true, false, true, + ) + aliceLink2.addAlias(aliceAlias) + } + err = s2.AddLink(aliceLink2) + require.NoError(t, err) + + bobLink2 := newMockChannelLink( + s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, + false, + ) + err = s2.AddLink(bobLink2) + require.NoError(t, err) + + // Reforward the ogPacket and wait for Alice to receive a failure + // packet. + err = s2.ForwardPackets(nil, ogPacket) + require.NoError(t, err) + + select { + case failPacket := <-aliceLink2.packets: + // Assert that the failPacket does not leak UTXO information. + // This means checking that aliceChanID was not returned. + msg := failPacket.linkFailure.msg + failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) + require.True(t, ok) + require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) + case <-s2.quit: + t.Fatal("switch shutting down, failed to forward packet") + } +} + +// TestSwitchAliasFailAdd tests that the mailbox does not leak UTXO information +// when failing back an HTLC due to the 5-second timeout. This is tested in the +// switch rather than the mailbox because the mailbox tests do not have the +// proper context (e.g. the Switch's failAliasUpdate function). The caveat here +// is that if the private UTXO is already known, it is fine to send a failure +// back. This tests option-scid-alias (feature-bit) and zero-conf channels. +func TestSwitchAliasFailAdd(t *testing.T) { + tests := []struct { + name string + + // Denotes whether the opened channel will be zero-conf. + zeroConf bool + + // Denotes whether the opened channel will be private. + private bool + + // Denotes whether an alias was used during forwarding. + useAlias bool + }{ + { + name: "public zero-conf using alias", + zeroConf: true, + private: false, + useAlias: true, + }, + { + name: "public zero-conf using real", + zeroConf: true, + private: false, + useAlias: true, + }, + { + name: "private zero-conf using alias", + zeroConf: true, + private: true, + useAlias: true, + }, + { + name: "public option-scid-alias using alias", + zeroConf: false, + private: false, + useAlias: true, + }, + { + name: "public option-scid-alias using real", + zeroConf: false, + private: false, + useAlias: false, + }, + { + name: "private option-scid-alias using alias", + zeroConf: false, + private: true, + useAlias: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + testSwitchAliasFailAdd( + t, test.zeroConf, test.private, test.useAlias, + ) + }) + } +} + +func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + tempPath, err := ioutil.TempDir("", "circuitdb") + require.NoError(t, err) + + cdb, err := channeldb.Open(tempPath) + require.NoError(t, err) + + s, err := initSwitchWithDB(testStartingHeight, cdb) + require.NoError(t, err) + + // Change the mailOrchestrator's expiry to a second. + s.mailOrchestrator.cfg.expiry = time.Second + + err = s.Start() + require.NoError(t, err) + + defer func() { + _ = s.Stop() + _ = os.RemoveAll(tempPath) + }() + + // Make Alice's channel zero-conf or option-scid-alias (feature bit). + aliceAlias := lnwire.ShortChannelID{ + BlockHeight: 16_000_000, + TxIndex: 5, + TxPosition: 5, + } + aliceAlias2 := aliceAlias + aliceAlias2.TxPosition = 6 + + var aliceLink *mockChannelLink + if zeroConf { + aliceLink = newMockChannelLink( + s, chanID1, aliceAlias, aliceChanID, alicePeer, true, + private, true, false, + ) + aliceLink.addAlias(aliceAlias2) + } else { + aliceLink = newMockChannelLink( + s, chanID1, aliceChanID, emptyScid, alicePeer, true, + private, false, true, + ) + aliceLink.addAlias(aliceAlias) + aliceLink.addAlias(aliceAlias2) + } + err = s.AddLink(aliceLink) + require.NoError(t, err) + + bobLink := newMockChannelLink( + s, chanID2, bobChanID, emptyScid, bobPeer, true, true, false, + false, + ) + err = s.AddLink(bobLink) + require.NoError(t, err) + + // Create a packet that Bob will send to Alice via ForwardPackets. + preimage := [sha256.Size]byte{1} + rhash := sha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: bobLink.ShortChanID(), + incomingHTLCID: 0, + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + // Determine which outgoingChanID to set based on the useAlias boolean. + outgoingChanID := aliceChanID + if useAlias { + // Choose randomly from the 2 possible aliases. + aliases := aliceLink.getAliases() + idx := mrand.Intn(len(aliases)) + + outgoingChanID = aliases[idx] + } + + ogPacket.outgoingChanID = outgoingChanID + + // Forward the packet so Alice's mailbox fails it backwards. + err = s.ForwardPackets(nil, ogPacket) + require.NoError(t, err) + + // Assert that the circuits are in the expected state. + require.Equal(t, 1, s.circuits.NumPending()) + require.Equal(t, 0, s.circuits.NumOpen()) + + // Wait to receive the packet from Bob's mailbox. + select { + case failPacket := <-bobLink.packets: + // Assert that failPacket returns the expected SCID in the + // ChannelUpdate. + msg := failPacket.linkFailure.msg + failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) + require.True(t, ok) + require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + case <-s.quit: + t.Fatal("switch shutting down, failed to receive fail packet") + } +} + +// TestSwitchHandlePacketForwardAlias checks that handlePacketForward (which +// calls CheckHtlcForward) does not leak the UTXO in a failure message for +// alias channels. This test requires us to have a REAL link, which we also +// must modify in order to test it properly (e.g. making it a private channel). +// This doesn't lead to good code, but short of refactoring the link-generation +// code there is not a good alternative. +func TestSwitchHandlePacketForward(t *testing.T) { + tests := []struct { + name string + + // Denotes whether or not the channel will be zero-conf. + zeroConf bool + + // Denotes whether or not the channel will have negotiated the + // option-scid-alias feature-bit and is not zero-conf. + optionFeature bool + + // Denotes whether or not the channel will be private. + private bool + + // Denotes whether or not the alias will be used for + // forwarding. + useAlias bool + }{ + { + name: "public zero-conf using alias", + zeroConf: true, + private: false, + useAlias: true, + }, + { + name: "public zero-conf using real", + zeroConf: true, + private: false, + useAlias: false, + }, + { + name: "private zero-conf using alias", + zeroConf: true, + private: true, + useAlias: true, + }, + { + name: "public option-scid-alias using alias", + zeroConf: false, + optionFeature: true, + private: false, + useAlias: true, + }, + { + name: "public option-scid-alias using real", + zeroConf: false, + optionFeature: true, + private: false, + useAlias: false, + }, + { + name: "private option-scid-alias using alias", + zeroConf: false, + optionFeature: true, + private: true, + useAlias: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + testSwitchHandlePacketForward( + t, test.zeroConf, test.private, test.useAlias, + test.optionFeature, + ) + }) + } +} + +func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, + useAlias, optionFeature bool) { + + t.Parallel() + + // Create a link for Alice that we'll add to the switch. + aliceLink, _, _, _, cleanUp, _, err := + newSingleLinkTestHarness(btcutil.SatoshiPerBitcoin, 0) + require.NoError(t, err) + defer cleanUp() + + s, err := initSwitchWithDB(testStartingHeight, nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer func() { + _ = s.Stop() + }() + + // Change Alice's ShortChanID and OtherShortChanID here. + aliceAlias := lnwire.ShortChannelID{ + BlockHeight: 16_000_000, + TxIndex: 5, + TxPosition: 5, + } + aliceAlias2 := aliceAlias + aliceAlias2.TxPosition = 6 + + aliceChannelLink := aliceLink.(*channelLink) + aliceChannelState := aliceChannelLink.channel.State() + + // Set the link's GetAliases function. + aliceChannelLink.cfg.GetAliases = func( + base lnwire.ShortChannelID) []lnwire.ShortChannelID { + + return []lnwire.ShortChannelID{aliceAlias, aliceAlias2} + } + + if !private { + // Change the channel to public depending on the test. + aliceChannelState.ChannelFlags = lnwire.FFAnnounceChannel + } + + // If this is an option-scid-alias feature-bit non-zero-conf channel, + // we'll mark the channel as such. + if optionFeature { + aliceChannelState.ChanType |= channeldb.ScidAliasFeatureBit + } + + // This is the ShortChannelID field in the OpenChannel struct. + aliceScid := aliceLink.ShortChanID() + if zeroConf { + // Store the alias in the shortChanID field and mark the real + // scid in the database. + aliceChannelLink.shortChanID = aliceAlias + err = aliceChannelState.MarkRealScid(aliceScid) + require.NoError(t, err) + + aliceChannelState.ChanType |= channeldb.ZeroConfBit + } + + err = s.AddLink(aliceLink) + require.NoError(t, err) + + // Add a mockChannelLink for Bob. + bobChanID, bobScid := genID() + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + bobLink := newMockChannelLink( + s, bobChanID, bobScid, emptyScid, bobPeer, true, false, false, + false, + ) + err = s.AddLink(bobLink) + require.NoError(t, err) + + preimage := [sha256.Size]byte{1} + rhash := sha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: bobLink.ShortChanID(), + incomingHTLCID: 0, + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + // Determine which outgoingChanID to set based on the useAlias bool. + outgoingChanID := aliceScid + if useAlias { + // Choose from the possible aliases. + aliases := aliceLink.getAliases() + idx := mrand.Intn(len(aliases)) + + outgoingChanID = aliases[idx] + } + + ogPacket.outgoingChanID = outgoingChanID + + // Forward the packet to Alice and she should fail it back with an + // AmountBelowMinimum FailureMessage. + err = s.ForwardPackets(nil, ogPacket) + require.NoError(t, err) + + select { + case failPacket := <-bobLink.packets: + // Assert that failPacket returns the expected ChannelUpdate. + msg := failPacket.linkFailure.msg + failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) + require.True(t, ok) + require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + case <-s.quit: + t.Fatal("switch shutting down, failed to receive failure") + } +} + +// TestSwitchAliasInterceptFail tests that when the InterceptableSwitch fails +// an incoming HTLC, it does not leak the on-chain UTXO for option-scid-alias +// (feature bit) or zero-conf channels. +func TestSwitchAliasInterceptFail(t *testing.T) { + tests := []struct { + name string + + // Denotes whether or not the incoming channel is a zero-conf + // channel or an option-scid-alias channel instead (feature + // bit). + zeroConf bool + }{ + { + name: "option-scid-alias", + zeroConf: false, + }, + { + name: "zero-conf", + zeroConf: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + testSwitchAliasInterceptFail(t, test.zeroConf) + }) + } +} + +func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { + t.Parallel() + + chanID, aliceScid := genID() + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + require.NoError(t, err) + + tempPath, err := ioutil.TempDir("", "circuitdb") + require.NoError(t, err) + + cdb, err := channeldb.Open(tempPath) + require.NoError(t, err) + + s, err := initSwitchWithDB(testStartingHeight, cdb) + require.NoError(t, err) + + err = s.Start() + require.NoError(t, err) + + defer func() { + _ = s.Stop() + _ = os.RemoveAll(tempPath) + }() + + // Make Alice's alias here. + aliceAlias := lnwire.ShortChannelID{ + BlockHeight: 16_000_000, + TxIndex: 5, + TxPosition: 5, + } + aliceAlias2 := aliceAlias + aliceAlias2.TxPosition = 6 + + var aliceLink *mockChannelLink + if zeroConf { + aliceLink = newMockChannelLink( + s, chanID, aliceAlias, aliceScid, alicePeer, true, + true, true, false, + ) + aliceLink.addAlias(aliceAlias2) + } else { + aliceLink = newMockChannelLink( + s, chanID, aliceScid, emptyScid, alicePeer, true, + true, false, true, + ) + aliceLink.addAlias(aliceAlias) + aliceLink.addAlias(aliceAlias2) + } + err = s.AddLink(aliceLink) + require.NoError(t, err) + + // Now we'll create the packet that will be sent from the Alice link. + preimage := [sha256.Size]byte{1} + rhash := sha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceLink.ShortChanID(), + incomingTimeout: 1000, + incomingHTLCID: 0, + outgoingChanID: lnwire.ShortChannelID{}, + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + // Now setup the interceptable switch so that we can reject this + // packet. + forwardInterceptor := &mockForwardInterceptor{ + t: t, + interceptedChan: make(chan InterceptedPacket), + } + interceptSwitch := NewInterceptableSwitch(s, 0, false) + require.NoError(t, interceptSwitch.Start()) + interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) + + err = interceptSwitch.ForwardPackets(nil, false, ogPacket) + require.NoError(t, err) + + inCircuit := forwardInterceptor.getIntercepted().IncomingCircuit + require.NoError(t, interceptSwitch.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: inCircuit, + FailureCode: lnwire.CodeTemporaryChannelFailure, + })) + + select { + case failPacket := <-aliceLink.packets: + // Assert that failPacket returns the expected ChannelUpdate. + failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC) + require.True(t, ok) + + r := bytes.NewReader(failHtlc.Reason) + failure, err := lnwire.DecodeFailure(r, 0) + require.NoError(t, err) + + failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure) + require.True(t, ok) + + failScid := failureMsg.Update.ShortChannelID + isAlias := failScid == aliceAlias || failScid == aliceAlias2 + require.True(t, isAlias) + + case <-s.quit: + t.Fatalf("switch shutting down, failed to receive failure") + } + + require.NoError(t, interceptSwitch.Stop()) +} diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 7000c119e..75eeeee69 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -1128,6 +1128,12 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, return nil } + getAliases := func( + base lnwire.ShortChannelID) []lnwire.ShortChannelID { + + return nil + } + link := NewChannelLink( ChannelLinkConfig{ Switch: server.htlcSwitch, @@ -1168,6 +1174,7 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, NotifyActiveChannel: func(wire.OutPoint) {}, NotifyInactiveChannel: func(wire.OutPoint) {}, HtlcNotifier: server.htlcSwitch.cfg.HtlcNotifier, + GetAliases: getAliases, }, channel, ) diff --git a/server.go b/server.go index 0f79612b9..0ff753949 100644 --- a/server.go +++ b/server.go @@ -657,6 +657,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, Clock: clock.NewDefaultClock(), HTLCExpiry: htlcswitch.DefaultHTLCExpiry, DustThreshold: thresholdMSats, + SignAliasUpdate: s.signAliasUpdate, + IsAlias: aliasmgr.IsAlias, }, uint32(currentHeight)) if err != nil { return nil, err