diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index c1fd4b94e..b1c15456c 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1100,15 +1100,16 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { // Generate payment invoice and htlc, but don't add this invoice to the // receiver registry. This should trigger an unknown payment hash // failure. - _, htlc, err := generatePayment(amount, htlcAmt, totalTimelock, - blob) + _, htlc, pid, err := generatePayment( + amount, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatal(err) } // Send payment and expose err channel. _, err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), htlc, + n.firstBobChannelLink.ShortChanID(), pid, htlc, newMockDeobfuscator(), ) if !strings.Contains(err.Error(), lnwire.CodeUnknownPaymentHash.String()) { @@ -1909,7 +1910,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { // a switch initiated payment. The resulting bandwidth should // now be decremented to reflect the new HTLC. htlcAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + invoice, htlc, _, err := generatePayment( + htlcAmt, htlcAmt, 5, mockBlob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -1989,7 +1992,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { // Next, we'll add another HTLC initiated by the switch (of the same // amount as the prior one). - invoice, htlc, err = generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + invoice, htlc, _, err = generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2075,8 +2078,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if err != nil { t.Fatalf("unable to gen route: %v", err) } - invoice, htlc, err = generatePayment(htlcAmt, htlcAmt, - totalTimelock, blob) + invoice, htlc, _, err = generatePayment( + htlcAmt, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2183,7 +2187,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if err != nil { t.Fatalf("unable to gen route: %v", err) } - invoice, htlc, err = generatePayment(htlcAmt, htlcAmt, totalTimelock, blob) + invoice, htlc, _, err = generatePayment( + htlcAmt, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2314,7 +2320,9 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) { var htlcID uint64 addLinkHTLC := func(id uint64, amt lnwire.MilliSatoshi) [32]byte { - invoice, htlc, err := generatePayment(amt, amt, 5, mockBlob) + invoice, htlc, _, err := generatePayment( + amt, amt, 5, mockBlob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2580,7 +2588,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) { // message for the test. var mockBlob [lnwire.OnionPacketSize]byte htlcAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - _, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + _, htlc, _, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2860,7 +2868,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) { // message for the test. var mockBlob [lnwire.OnionPacketSize]byte htlcAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - _, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + _, htlc, _, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -3113,7 +3121,7 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) { // a switch initiated payment. The resulting bandwidth should // now be decremented to reflect the new HTLC. htlcAmt := lnwire.NewMSatFromSatoshis(3 * btcutil.SatoshiPerBitcoin) - invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + invoice, htlc, _, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -3844,8 +3852,9 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { if err != nil { t.Fatal(err) } - invoice, htlc, err := generatePayment(amount, htlcAmt, totalTimelock, - blob) + invoice, htlc, pid, err := generatePayment( + amount, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatal(err) } @@ -3859,7 +3868,7 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // payment. It should succeed w/o any issues as it has been crafted // properly. _, err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), htlc, + n.firstBobChannelLink.ShortChanID(), pid, htlc, newMockDeobfuscator(), ) if err != nil { @@ -3869,7 +3878,7 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // Now, if we attempt to send the payment *again* it should be rejected // as it's a duplicate request. _, err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), htlc, + n.firstBobChannelLink.ShortChanID(), pid, htlc, newMockDeobfuscator(), ) if err != ErrAlreadyPaid { @@ -4255,7 +4264,7 @@ func generateHtlcAndInvoice(t *testing.T, t.Fatalf("unable to generate route: %v", err) } - invoice, htlc, err := generatePayment( + invoice, htlc, _, err := generatePayment( htlcAmt, htlcAmt, uint32(htlcExpiry), blob, ) if err != nil { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f0c1c4252..cc4b49046 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -213,8 +213,6 @@ type Switch struct { pendingPayments map[uint64]*pendingPayment pendingMutex sync.RWMutex - paymentSequencer Sequencer - // control provides verification of sending htlc mesages control ControlTower @@ -293,16 +291,10 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { return nil, err } - sequencer, err := NewPersistentSequencer(cfg.DB) - if err != nil { - return nil, err - } - return &Switch{ bestHeight: currentHeight, cfg: &cfg, circuits: circuitMap, - paymentSequencer: sequencer, control: NewPaymentControl(false, cfg.DB), linkIndex: make(map[lnwire.ChannelID]ChannelLink), mailOrchestrator: newMailOrchestrator(), @@ -354,8 +346,9 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro } // SendHTLC is used by other subsystems which aren't belong to htlc switch -// package in order to send the htlc update. -func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, +// package in order to send the htlc update. The paymentID used MUST be unique +// for this HTLC, otherwise the switch might reject it. +func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) { @@ -376,11 +369,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, deobfuscator: deobfuscator, } - paymentID, err := s.paymentSequencer.NextID() - if err != nil { - return zeroPreimage, err - } - s.pendingMutex.Lock() s.pendingPayments[paymentID] = payment s.pendingMutex.Unlock() @@ -407,6 +395,7 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, // Returns channels so that other subsystem might wait/skip the // waiting of handling of payment. var preimage [sha256.Size]byte + var err error select { case e := <-payment.err: diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 75732a2d5..c6ea60a27 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, // We'll attempt to send out a new HTLC that has Alice as the first // outgoing link. This should fail as Alice isn't yet able to forward // any active HTLC's. - _, err = s.SendHTLC(aliceChannelLink.ShortChanID(), addMsg, nil) + _, err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg, nil) if err == nil { t.Fatalf("local forward should fail due to inactive link") } @@ -1743,7 +1743,7 @@ func TestSwitchSendPayment(t *testing.T) { errChan := make(chan error) go func() { _, err := s.SendHTLC( - aliceChannelLink.ShortChanID(), update, + aliceChannelLink.ShortChanID(), 0, update, newMockDeobfuscator()) errChan <- err }() @@ -1752,7 +1752,7 @@ func TestSwitchSendPayment(t *testing.T) { // Send the payment with the same payment hash and same // amount and check that it will be propagated successfully _, err := s.SendHTLC( - aliceChannelLink.ShortChanID(), update, + aliceChannelLink.ShortChanID(), 0, update, newMockDeobfuscator(), ) errChan <- err diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 2e1907f23..997daf5ec 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -543,7 +543,7 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) { func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, blob [lnwire.OnionPacketSize]byte, preimage, rhash [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, - error) { + uint64, error) { // Create the db invoice. Normally the payment requests needs to be set, // because it is decoded in InvoiceRegistry to obtain the cltv expiry. @@ -566,18 +566,25 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, OnionBlob: blob, } - return invoice, htlc, nil + pid, err := generateRandomBytes(8) + if err != nil { + return nil, nil, 0, err + } + paymentID := binary.BigEndian.Uint64(pid) + + return invoice, htlc, paymentID, nil } // generatePayment generates the htlc add request by given path blob and // invoice which should be added by destination peer. func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, - blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, error) { + blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice, + *lnwire.UpdateAddHTLC, uint64, error) { var preimage [sha256.Size]byte r, err := generateRandomBytes(sha256.Size) if err != nil { - return nil, nil, err + return nil, nil, 0, err } copy(preimage[:], r) @@ -772,7 +779,9 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, } // Generate payment: invoice and htlc. - invoice, htlc, err := generatePayment(invoiceAmt, htlcAmt, timelock, blob) + invoice, htlc, pid, err := generatePayment( + invoiceAmt, htlcAmt, timelock, blob, + ) if err != nil { return nil, nil, err } @@ -786,7 +795,7 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. return invoice, func() error { _, err := sender.htlcSwitch.SendHTLC( - firstHop, htlc, newMockDeobfuscator(), + firstHop, pid, htlc, newMockDeobfuscator(), ) return err }, nil @@ -1235,8 +1244,10 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, rhash := preimage.Hash() // Generate payment: invoice and htlc. - invoice, htlc, err := generatePaymentWithPreimage(invoiceAmt, htlcAmt, timelock, blob, - channeldb.UnknownPreimage, rhash) + invoice, htlc, pid, err := generatePaymentWithPreimage( + invoiceAmt, htlcAmt, timelock, blob, + channeldb.UnknownPreimage, rhash, + ) if err != nil { paymentErr <- err return paymentErr @@ -1251,7 +1262,7 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. go func() { _, err := sender.htlcSwitch.SendHTLC( - firstHop, htlc, newMockDeobfuscator(), + firstHop, pid, htlc, newMockDeobfuscator(), ) paymentErr <- err }() diff --git a/routing/mock_test.go b/routing/mock_test.go index 2552b2fd1..c5754e4ee 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -14,6 +14,7 @@ type mockPaymentAttemptDispatcher struct { var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, + _ uint64, _ *lnwire.UpdateAddHTLC, _ htlcswitch.ErrorDecrypter) ([sha256.Size]byte, error) { diff --git a/routing/router.go b/routing/router.go index 3b1d84780..8ca616cc3 100644 --- a/routing/router.go +++ b/routing/router.go @@ -134,6 +134,7 @@ type PaymentAttemptDispatcher interface { // denoted by its public key. A non-nil error is to be returned if the // payment was unsuccessful. SendHTLC(firstHop lnwire.ShortChannelID, + paymentID uint64, htlcAdd *lnwire.UpdateAddHTLC, deobfuscator htlcswitch.ErrorDecrypter) ([sha256.Size]byte, error) } @@ -208,6 +209,12 @@ type Config struct { // returned. QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + // NextPaymentID is a method that guarantees to return a new, unique ID + // each time it is called. This is used by the router to generate a + // unique payment ID for each payment it attempts to send, such that + // the switch can properly handle the HTLC. + NextPaymentID func() (uint64, error) + // AssumeChannelValid toggles whether or not the router will check for // spentness of channel outpoints. For neutrino, this saves long rescans // from blocking initial usage of the daemon. @@ -1715,8 +1722,15 @@ func (r *ChannelRouter) sendToSwitch(route *route.Route, paymentHash [32]byte) ( OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), } + // We generate a new, unique payment ID that we will use for + // this HTLC. + paymentID, err := r.cfg.NextPaymentID() + if err != nil { + return [32]byte{}, err + } + return r.cfg.Payer.SendHTLC( - firstHop, htlcAdd, errorDecryptor, + firstHop, paymentID, htlcAdd, errorDecryptor, ) } diff --git a/routing/router_test.go b/routing/router_test.go index b1e343839..9f4b6aea2 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -6,6 +6,7 @@ import ( "image/color" "math/rand" "strings" + "sync/atomic" "testing" "time" @@ -23,6 +24,8 @@ import ( "github.com/lightningnetwork/lnd/zpay32" ) +var uniquePaymentID uint64 = 1 // to be used atomically + type testCtx struct { router *ChannelRouter @@ -90,6 +93,10 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { return lnwire.NewMSatFromSatoshis(e.Capacity) }, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, }) if err != nil { return nil, nil, fmt.Errorf("unable to create router %v", err) diff --git a/server.go b/server.go index 02683c8f9..cc2033ad6 100644 --- a/server.go +++ b/server.go @@ -609,6 +609,13 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, } s.currentNodeAnn = nodeAnn + // The router will get access to the payment ID sequencer, such that it + // can generate unique payment IDs. + sequencer, err := htlcswitch.NewPersistentSequencer(chanDB) + if err != nil { + return nil, err + } + s.chanRouter, err = routing.New(routing.Config{ Graph: chanGraph, Chain: cc.chainIO, @@ -646,6 +653,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, return link.Bandwidth() }, AssumeChannelValid: cfg.Routing.UseAssumeChannelValid(), + NextPaymentID: sequencer.NextID, }) if err != nil { return nil, fmt.Errorf("can't create router: %v", err)