diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 067f32924..2639ae322 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -140,6 +140,11 @@ type ChannelLinkConfig struct { // the exit node. // NOTE: HodlHTLC should be active in conjunction with DebugHTLC. HodlHTLC bool + + // SyncStates is used to indicate that we need send the channel + // reestablishment message to the remote peer. It should be done if our + // clients have been restarted, or remote peer have been reconnected. + SyncStates bool } // channelLink is the service which drives a channel's commitment update @@ -260,8 +265,9 @@ var _ ChannelLink = (*channelLink)(nil) // NOTE: Part of the ChannelLink interface. func (l *channelLink) Start() error { if !atomic.CompareAndSwapInt32(&l.started, 0, 1) { - log.Warnf("channel link(%v): already started", l) - return nil + err := errors.Errorf("channel link(%v): already started", l) + log.Warn(err) + return err } log.Infof("ChannelLink(%v) is starting", l) @@ -312,6 +318,29 @@ func (l *channelLink) htlcManager() { log.Infof("HTLC manager for ChannelPoint(%v) started, "+ "bandwidth=%v", l.channel.ChannelPoint(), l.Bandwidth()) + // If the link have been recreated, than we need to sync the states by + // sending the channel reestablishment message. + if l.cfg.SyncStates { + log.Infof("Syncing states for channel(%v) via sending the "+ + "re-establishment message", l.channel.ChannelPoint()) + + localCommitmentNumber, remoteRevocationNumber := l.channel.LastCounters() + + l.cfg.Peer.SendMessage(&lnwire.ChannelReestablish{ + ChanID: l.ChanID(), + NextLocalCommitmentNumber: localCommitmentNumber + 1, + NextRemoteRevocationNumber: remoteRevocationNumber + 1, + }) + + if err := l.channelInitialization(); err != nil { + err := errors.Errorf("unable to sync the states for channel(%v)"+ + "with remote node: %v", l.ChanID(), err) + log.Error(err) + l.cfg.Peer.Disconnect(err) + return + } + } + // TODO(roasbeef): check to see if able to settle any currently pending // HTLCs // * also need signals when new invoices are added by the @@ -469,6 +498,7 @@ out: l.handleUpstreamMsg(msg) case cmd := <-l.linkControl: + switch req := cmd.(type) { case *policyUpdate: // In order to avoid overriding a valid policy @@ -681,6 +711,30 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { // direct channel with, updating our respective commitment chains. func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { switch msg := msg.(type) { + case *lnwire.ChannelReestablish: + log.Infof("Received re-establishment message from remote side "+ + "for channel(%v)", l.channel.ChannelPoint()) + + messagesToSyncState, err := l.channel.ReceiveReestablish(msg) + if err != nil { + err := errors.Errorf("unable to handle upstream reestablish "+ + "message: %v", err) + log.Error(err) + l.cfg.Peer.Disconnect(err) + return + } + + // Send message to the remote side which are needed to synchronize + // the state. + log.Infof("Sending %v updates to synchronize the "+ + "state for channel(%v)", len(messagesToSyncState), + l.channel.ChannelPoint()) + for _, msg := range messagesToSyncState { + l.cfg.Peer.SendMessage(msg) + } + + return + case *lnwire.UpdateAddHTLC: // We just received an add request from an upstream peer, so we // add it to our state machine, then add the HTLC to our @@ -774,7 +828,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { l.cancelReasons[idx] = msg.Reason case *lnwire.CommitSig: - // We just received a new update to our local commitment chain, + // We just received a new updates to our local commitment chain, // validate this new commitment, closing the link if invalid. err := l.channel.ReceiveNewCommitment(msg.CommitSig, msg.HtlcSigs) if err != nil { @@ -1513,3 +1567,37 @@ func (l *channelLink) fail(format string, a ...interface{}) { log.Error(reason) l.cfg.Peer.Disconnect(reason) } + +// channelInitialization waits for channel synchronization message to +// be received from another side and handled. +func (l *channelLink) channelInitialization() error { + // Before we launch any of the helper goroutines off the channel link + // struct, we'll first ensure proper adherence to the p2p protocol. The + // channel reestablish message MUST be sent before any other message. + expired := time.After(time.Second * 5) + + for { + select { + case msg := <-l.upstream: + if msg, ok := msg.(*lnwire.ChannelReestablish); ok { + l.handleUpstreamMsg(msg) + return nil + } else { + return errors.New("very first message between nodes " + + "for channel link should be reestablish message") + } + + case pkt := <-l.downstream: + l.overflowQueue.consume(pkt) + + case cmd := <-l.linkControl: + l.handleControlCommand(cmd) + + // In order to avoid blocking indefinitely, we'll give the other peer + // an upper timeout of 5 seconds to respond before we bail out early. + case <-expired: + return errors.Errorf("peer did not complete handshake for channel " + + "link within 5 seconds") + } + } +} diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index fc9655ebb..5e373bde6 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -16,6 +16,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/roasbeef/btcd/chaincfg/chainhash" @@ -79,15 +80,6 @@ func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor { } if chanID == channelID { - // Skip logging of extend revocation window messages. - switch m := m.(type) { - case *lnwire.RevokeAndAck: - var zeroHash chainhash.Hash - if bytes.Equal(zeroHash[:], m.Revocation[:]) { - return false, nil - } - } - fmt.Printf("---------------------- \n %v received: "+ "%v", name, messageToString(m)) } @@ -98,13 +90,13 @@ func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor { // createInterceptorFunc creates the function by the given set of messages // which, checks the order of the messages and skip the ones which were // indicated to be intercepted. -func createInterceptorFunc(peer string, messages []expectedMessage, +func createInterceptorFunc(prefix, receiver string, messages []expectedMessage, chanID lnwire.ChannelID, debug bool) messageInterceptor { // Filter message which should be received with given peer name. var expectToReceive []expectedMessage for _, message := range messages { - if message.to == peer { + if message.to == receiver { expectToReceive = append(expectToReceive, message) } } @@ -128,17 +120,24 @@ func createInterceptorFunc(peer string, messages []expectedMessage, if expectedMessage.message.MsgType() != m.MsgType() { return false, errors.Errorf("%v received wrong message: \n"+ - "real: %v\nexpected: %v", peer, m.MsgType(), + "real: %v\nexpected: %v", receiver, m.MsgType(), expectedMessage.message.MsgType()) } if debug { + var postfix string + if revocation, ok := m.(*lnwire.RevokeAndAck); ok { + var zeroHash chainhash.Hash + if bytes.Equal(zeroHash[:], revocation.Revocation[:]) { + postfix = "- empty revocation" + } + } + if expectedMessage.skip { - fmt.Printf("'%v' skiped the received message: %v \n", - peer, m.MsgType()) + fmt.Printf("skipped: %v: %v %v \n", prefix, + m.MsgType(), postfix) } else { - fmt.Printf("'%v' received message: %v \n", peer, - m.MsgType()) + fmt.Printf("%v: %v %v \n", prefix, m.MsgType(), postfix) } } @@ -153,13 +152,17 @@ func createInterceptorFunc(peer string, messages []expectedMessage, func TestChannelLinkSingleHopPayment(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -189,7 +192,8 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { // * settle request to be sent back from bob to alice. // * alice<->bob commitment state to be updated. // * user notification to be sent. - invoice, err := n.makePayment(n.aliceServer, n.bobServer, + receiver := n.bobServer + rhash, err := n.makePayment(n.aliceServer, receiver, n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock).Wait(10 * time.Second) if err != nil { @@ -203,8 +207,12 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { // Check that alice invoice was settled and bandwidth of HTLC // links was changed. + invoice, err := receiver.registry.LookupInvoice(rhash) + if err != nil { + t.Fatalf("unable to get inveoice: %v", err) + } if !invoice.Terms.Settled { - t.Fatal("invoice wasn't settled") + t.Fatal("alice invoice wasn't settled") } if aliceBandwidthBefore-amount != n.aliceChannelLink.Bandwidth() { @@ -223,18 +231,21 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { func TestChannelLinkBidirectionalOneHopPayments(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } defer n.stop() - bobBandwidthBefore := n.firstBobChannelLink.Bandwidth() aliceBandwidthBefore := n.aliceChannelLink.Bandwidth() @@ -347,13 +358,17 @@ func TestChannelLinkBidirectionalOneHopPayments(t *testing.T) { func TestChannelLinkMultiHopPayment(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -398,7 +413,8 @@ func TestChannelLinkMultiHopPayment(t *testing.T) { // * settle request to be sent back from Bob to Alice. // * Alice<->Bob commitment states to be updated. // * user notification to be sent. - invoice, err := n.makePayment(n.aliceServer, n.carolServer, + receiver := n.carolServer + rhash, err := n.makePayment(n.aliceServer, n.carolServer, n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock).Wait(10 * time.Second) if err != nil { @@ -410,8 +426,12 @@ func TestChannelLinkMultiHopPayment(t *testing.T) { // Check that Carol invoice was settled and bandwidth of HTLC // links were changed. + invoice, err := receiver.registry.LookupInvoice(rhash) + if err != nil { + t.Fatalf("unable to get inveoice: %v", err) + } if !invoice.Terms.Settled { - t.Fatal("alice invoice wasn't settled") + t.Fatal("carol invoice haven't been settled") } expectedAliceBandwidth := aliceBandwidthBefore - htlcAmt @@ -446,13 +466,17 @@ func TestChannelLinkMultiHopPayment(t *testing.T) { func TestExitNodeTimelockPayloadMismatch(t *testing.T) { t.Parallel() + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*5, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -468,7 +492,7 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { // the receiving node, instead we set it to be a random value. hops[0].OutgoingCTLV = 500 - _, err := n.makePayment(n.aliceServer, n.bobServer, + _, err = n.makePayment(n.aliceServer, n.bobServer, n.bobServer.PubKey(), hops, amount, htlcAmt, htlcExpiry).Wait(10 * time.Second) if err == nil { @@ -495,13 +519,17 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { func TestExitNodeAmountPayloadMismatch(t *testing.T) { t.Parallel() + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*5, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -517,7 +545,7 @@ func TestExitNodeAmountPayloadMismatch(t *testing.T) { // receiving node expects to receive. hops[0].AmountToForward = 1 - _, err := n.makePayment(n.aliceServer, n.bobServer, + _, err = n.makePayment(n.aliceServer, n.bobServer, n.bobServer.PubKey(), hops, amount, htlcAmt, htlcExpiry).Wait(10 * time.Second) if err == nil { @@ -536,13 +564,17 @@ func TestExitNodeAmountPayloadMismatch(t *testing.T) { func TestLinkForwardTimelockPolicyMismatch(t *testing.T) { t.Parallel() + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*5, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -560,7 +592,7 @@ func TestLinkForwardTimelockPolicyMismatch(t *testing.T) { // Next, we'll make the payment which'll send an HTLC with our // specified parameters to the first hop in the route. - _, err := n.makePayment(n.aliceServer, n.carolServer, + _, err = n.makePayment(n.aliceServer, n.carolServer, n.bobServer.PubKey(), hops, amount, htlcAmt, htlcExpiry).Wait(10 * time.Second) // We should get an error, and that error should indicate that the HTLC @@ -588,13 +620,17 @@ func TestLinkForwardTimelockPolicyMismatch(t *testing.T) { func TestLinkForwardFeePolicyMismatch(t *testing.T) { t.Parallel() + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*3, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*5, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -612,7 +648,7 @@ func TestLinkForwardFeePolicyMismatch(t *testing.T) { // Next, we'll make the payment which'll send an HTLC with our // specified parameters to the first hop in the route. - _, err := n.makePayment(n.aliceServer, n.bobServer, + _, err = n.makePayment(n.aliceServer, n.bobServer, n.bobServer.PubKey(), hops, amountNoFee, amountNoFee, htlcExpiry).Wait(10 * time.Second) @@ -641,13 +677,17 @@ func TestLinkForwardFeePolicyMismatch(t *testing.T) { func TestLinkForwardMinHTLCPolicyMismatch(t *testing.T) { t.Parallel() + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*5, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -665,7 +705,7 @@ func TestLinkForwardMinHTLCPolicyMismatch(t *testing.T) { // Next, we'll make the payment which'll send an HTLC with our // specified parameters to the first hop in the route. - _, err := n.makePayment(n.aliceServer, n.bobServer, + _, err = n.makePayment(n.aliceServer, n.bobServer, n.bobServer.PubKey(), hops, amountNoFee, htlcAmt, htlcExpiry).Wait(10 * time.Second) @@ -695,13 +735,17 @@ func TestLinkForwardMinHTLCPolicyMismatch(t *testing.T) { func TestUpdateForwardingPolicy(t *testing.T) { t.Parallel() + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*5, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -730,9 +774,14 @@ func TestUpdateForwardingPolicy(t *testing.T) { // Carol's invoice should now be shown as settled as the payment // succeeded. - if !invoice.Terms.Settled { - t.Fatal("carol's invoice wasn't settled") + invoice, err := receiver.registry.LookupInvoice(rhash) + if err != nil { + t.Fatalf("unable to get invoice: %v", err) } + if !invoice.Terms.Settled { + t.Fatal("carol invoice haven't been settled") + } + expectedAliceBandwidth := aliceBandwidthBefore - htlcAmt if expectedAliceBandwidth != n.aliceChannelLink.Bandwidth() { t.Fatalf("channel bandwidth incorrect: expected %v, got %v", @@ -790,13 +839,17 @@ func TestUpdateForwardingPolicy(t *testing.T) { func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*3, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*3) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) } @@ -817,7 +870,9 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { // * Bob trying to add HTLC add request in Bob<->Carol channel. // * Cancel HTLC request to be sent back from Bob to Alice. // * user notification to be sent. - invoice, err := n.makePayment(n.aliceServer, n.carolServer, + + receiver := n.carolServer + rhash, err := n.makePayment(n.aliceServer, n.carolServer, n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock).Wait(10 * time.Second) if err == nil { @@ -833,8 +888,12 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. + invoice, err := receiver.registry.LookupInvoice(rhash) + if err != nil { + t.Fatalf("unable to get inveoice: %v", err) + } if invoice.Terms.Settled { - t.Fatal("alice invoice was settled") + t.Fatal("carol invoice have been settled") } if n.aliceChannelLink.Bandwidth() != aliceBandwidthBefore { @@ -863,13 +922,17 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*3, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) } @@ -949,13 +1012,17 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, - btcutil.SatoshiPerBitcoin*3, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -971,8 +1038,8 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { n.firstBobChannelLink, n.carolChannelLink) davePub := newMockServer("save", serverErr).PubKey() - - invoice, err := n.makePayment(n.aliceServer, n.bobServer, davePub, hops, + receiver := n.bobServer + rhash, err := n.makePayment(n.aliceServer, n.bobServer, davePub, hops, amount, htlcAmt, totalTimelock).Wait(10 * time.Second) if err == nil { t.Fatal("error haven't been received") @@ -987,8 +1054,12 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. + invoice, err := receiver.registry.LookupInvoice(rhash) + if err != nil { + t.Fatalf("unable to get inveoice: %v", err) + } if invoice.Terms.Settled { - t.Fatal("alice invoice was settled") + t.Fatal("carol invoice have been settled") } if n.aliceChannelLink.Bandwidth() != aliceBandwidthBefore { @@ -1017,13 +1088,17 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { func TestChannelLinkMultiHopDecodeError(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) } @@ -1044,7 +1119,8 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight, n.firstBobChannelLink, n.carolChannelLink) - invoice, err := n.makePayment(n.aliceServer, n.carolServer, + receiver := n.carolServer + rhash, err := n.makePayment(n.aliceServer, n.carolServer, n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock).Wait(10 * time.Second) if err == nil { @@ -1067,8 +1143,12 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. + invoice, err := receiver.registry.LookupInvoice(rhash) + if err != nil { + t.Fatalf("unable to get inveoice: %v", err) + } if invoice.Terms.Settled { - t.Fatal("alice invoice was settled") + t.Fatal("carol invoice have been settled") } if n.aliceChannelLink.Bandwidth() != aliceBandwidthBefore { @@ -1100,12 +1180,18 @@ func TestChannelLinkExpiryTooSoonExitNode(t *testing.T) { // The starting height for this test will be 200. So we'll base all // HTLC starting points off of that. - const startingHeight = 200 - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - startingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + const startingHeight = 200 + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, startingHeight) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) } @@ -1119,8 +1205,9 @@ func TestChannelLinkExpiryTooSoonExitNode(t *testing.T) { startingHeight-10, n.firstBobChannelLink) // Now we'll send out the payment from Alice to Bob. - _, err := n.makePayment(n.aliceServer, n.bobServer, - n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock) + _, err = n.makePayment(n.aliceServer, n.bobServer, + n.bobServer.PubKey(), hops, amount, htlcAmt, + totalTimelock).Wait(time.Second) // The payment should've failed as the time lock value was in the // _past_. @@ -1150,12 +1237,18 @@ func TestChannelLinkExpiryTooSoonMidNode(t *testing.T) { // The starting height for this test will be 200. So we'll base all // HTLC starting points off of that. - const startingHeight = 200 - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - startingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + const startingHeight = 200 + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, startingHeight) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) } @@ -1170,8 +1263,9 @@ func TestChannelLinkExpiryTooSoonMidNode(t *testing.T) { startingHeight-10, n.firstBobChannelLink, n.carolChannelLink) // Now we'll send out the payment from Alice to Bob. - _, err := n.makePayment(n.aliceServer, n.bobServer, - n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock) + _, err = n.makePayment(n.aliceServer, n.bobServer, + n.bobServer.PubKey(), hops, amount, htlcAmt, + totalTimelock).Wait(time.Second) // The payment should've failed as the time lock value was in the // _past_. @@ -1199,17 +1293,24 @@ func TestChannelLinkExpiryTooSoonMidNode(t *testing.T) { func TestChannelLinkSingleHopMessageOrdering(t *testing.T) { t.Parallel() - serverErr := make(chan error, 4) - n := newThreeHopNetwork(t, + channels, cleanUp, _, err := createClusterChannels( btcutil.SatoshiPerBitcoin*3, - btcutil.SatoshiPerBitcoin*5, - serverErr, - testStartingHeight, - ) + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + serverErr := make(chan error, 4) + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) chanID := n.aliceChannelLink.ChanID() messages := []expectedMessage{ + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + {"alice", "bob", &lnwire.UpdateAddHTLC{}, false}, {"alice", "bob", &lnwire.CommitSig{}, false}, {"bob", "alice", &lnwire.RevokeAndAck{}, false}, @@ -1235,12 +1336,12 @@ func TestChannelLinkSingleHopMessageOrdering(t *testing.T) { } // Check that alice receives messages in right order. - n.aliceServer.intersect(createInterceptorFunc("alice", messages, chanID, - false)) + n.aliceServer.intersect(createInterceptorFunc("[alice] <-- [bob]", + "alice", messages, chanID, false)) // Check that bob receives messages in right order. - n.bobServer.intersect(createInterceptorFunc("bob", messages, chanID, - false)) + n.bobServer.intersect(createInterceptorFunc("[alice] --> [bob]", + "bob", messages, chanID, false)) if err := n.start(); err != nil { t.Fatalf("unable to start three hop network: %v", err) @@ -1257,7 +1358,8 @@ func TestChannelLinkSingleHopMessageOrdering(t *testing.T) { // * settle request to be sent back from bob to alice. // * alice<->bob commitment state to be updated. // * user notification to be sent. - _, err = n.makePayment(n.aliceServer, n.bobServer, + receiver := n.bobServer + _, err = n.makePayment(n.aliceServer, receiver, n.bobServer.PubKey(), hops, amount, htlcAmt, totalTimelock).Wait(10 * time.Second) if err != nil { @@ -1598,3 +1700,236 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) { coreLink.overflowQueue.Length()) } } + +var retransmissionTests = []struct { + name string + messages []expectedMessage +}{ + { + // Tests the ability of the channel links states to be synchronized + // after remote node haven't receive revoke and ack message. + name: "intercept last alice revoke_and_ack", + messages: []expectedMessage{ + // First initialization of the channel. + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + + // Send payment from Alice to Bob and intercept the last revocation + // message, in this case Bob should not proceed the payment farther. + {"alice", "bob", &lnwire.UpdateAddHTLC{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, true}, + + // Reestablish messages exchange on nodes restart. + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + + // Alice should resend the revoke_and_ack message to Bob because Bob + // claimed it in the reestbalish message. + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + + // Proceed the payment farther by sending the fulfilment message and + // trigger the state update. + {"bob", "alice", &lnwire.UpdateFufillHTLC{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + }, + }, + { + // Tests the ability of the channel links states to be synchronized + // after remote node haven't receive revoke and ack message. + name: "intercept bob revoke_and_ack commit_sig messages", + messages: []expectedMessage{ + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + + // Send payment from Alice to Bob and intercept the last revocation + // message, in this case Bob should not proceed the payment farther. + {"alice", "bob", &lnwire.UpdateAddHTLC{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + + // Intercept bob commit sig and revoke and ack messages. + {"bob", "alice", &lnwire.RevokeAndAck{}, true}, + {"bob", "alice", &lnwire.CommitSig{}, true}, + + // Reestablish messages exchange on nodes restart. + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + + // Bob should resend previously intercepted messages. + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + + // Proceed the payment farther by sending the fulfilment message and + // trigger the state update. + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + {"bob", "alice", &lnwire.UpdateFufillHTLC{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + }, + }, + { + // Tests the ability of the channel links states to be synchronized + // after remote node haven't receive update and commit sig messages. + name: "intercept update add htlc and commit sig messages", + messages: []expectedMessage{ + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + + // Attempt make a payment from Alice to Bob, which is intercepted, + // emulating the Bob server abrupt stop. + {"alice", "bob", &lnwire.UpdateAddHTLC{}, true}, + {"alice", "bob", &lnwire.CommitSig{}, true}, + + // Restart of the nodes, and after that nodes should exchange the + // reestablish messages. + {"alice", "bob", &lnwire.ChannelReestablish{}, false}, + {"bob", "alice", &lnwire.ChannelReestablish{}, false}, + + // After Bob has notified Alice that he didn't receive updates Alice + // should re-send them. + {"alice", "bob", &lnwire.UpdateAddHTLC{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + + {"bob", "alice", &lnwire.UpdateFufillHTLC{}, false}, + {"bob", "alice", &lnwire.CommitSig{}, false}, + {"alice", "bob", &lnwire.RevokeAndAck{}, false}, + {"alice", "bob", &lnwire.CommitSig{}, false}, + {"bob", "alice", &lnwire.RevokeAndAck{}, false}, + }, + }, +} + +// TestChannelRetransmission tests the ability of the channel links to +// synchronize theirs states after abrupt disconnect. +func TestChannelRetransmission(t *testing.T) { + t.Parallel() + + paymentWithRestart := func(t *testing.T, messages []expectedMessage) { + channels, cleanUp, restoreChannelsFromDb, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*5, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + chanID := lnwire.NewChanIDFromOutPoint(channels.aliceToBob.ChannelPoint()) + serverErr := make(chan error, 4) + + aliceInterceptor := createInterceptorFunc("[alice] <-- [bob]", + "alice", messages, chanID, false) + bobInterceptor := createInterceptorFunc("[alice] --> [bob]", + "bob", messages, chanID, false) + + // Add interceptor to check the order of Bob and Alice messages. + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) + n.aliceServer.intersect(aliceInterceptor) + n.bobServer.intersect(bobInterceptor) + if err := n.start(); err != nil { + t.Fatalf("unable to start three hop network: %v", err) + } + defer n.stop() + + bobBandwidthBefore := n.firstBobChannelLink.Bandwidth() + aliceBandwidthBefore := n.aliceChannelLink.Bandwidth() + + amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) + htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight, + n.firstBobChannelLink) + + // Send payment which should fail because we intercept the update and + // commit messages. + receiver := n.bobServer + rhash, err := n.makePayment(n.aliceServer, receiver, + n.bobServer.PubKey(), hops, amount, htlcAmt, + totalTimelock).Wait(time.Millisecond * 100) + if err == nil { + t.Fatalf("payment shouldn't haven been finished") + } + + // Stop network cluster and create new one, with the old channels + // states. Also do the *hack* - save the payment receiver to pass it + // in new channel link, otherwise payment will be failed because of the + // unknown payment hash. Hack will be removed with sphinx payment. + bobRegistry := n.bobServer.registry + n.stop() + + channels, err = restoreChannelsFromDb() + if err != nil { + t.Fatalf("unable to restore channels from database: %v", err) + } + + n = newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, serverErr, testStartingHeight) + n.firstBobChannelLink.cfg.Registry = bobRegistry + n.aliceServer.intersect(aliceInterceptor) + n.bobServer.intersect(bobInterceptor) + + if err := n.start(); err != nil { + t.Fatalf("unable to start three hop network: %v", err) + } + defer n.stop() + + // Wait for reestablishment to be proceeded and invoice to be settled. + // TODO(andrew.shvv) Will be removed if we move the notification center + // to the channel link itself. + + var invoice *channeldb.Invoice + for i := 0; i < 20; i++ { + select { + case <-time.After(time.Millisecond * 200): + case serverErr := <-serverErr: + t.Fatalf("server error: %v", serverErr) + } + + // Check that alice invoice wasn't settled and bandwidth of htlc + // links hasn't been changed. + invoice, err = receiver.registry.LookupInvoice(rhash) + if err != nil { + err = errors.Errorf("unable to get invoice: %v", err) + continue + } + if !invoice.Terms.Settled { + err = errors.Errorf("alice invoice haven't been settled") + continue + } + + if aliceBandwidthBefore-amount != n.aliceChannelLink.Bandwidth() { + err = errors.Errorf("alice bandwidth should have been increased" + + " on payment amount") + continue + } + + if bobBandwidthBefore+amount != n.firstBobChannelLink.Bandwidth() { + err = errors.Errorf("bob bandwidth should have been increased " + + "on payment amount") + continue + } + + break + } + + if err != nil { + t.Fatal(err) + } + + } + + for _, test := range retransmissionTests { + t.Run(test.name, func(t *testing.T) { + paymentWithRestart(t, test.messages) + }) + } +} diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index e7b062024..7458f8104 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -62,15 +62,21 @@ func newMockServer(name string, errChan chan error) *mockServer { func (s *mockServer) Start() error { if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { - return nil + return errors.New("mock server already started") } - s.htlcSwitch.Start() + if err := s.htlcSwitch.Start(); err != nil { + return err + } s.wg.Add(1) go func() { defer s.wg.Done() + defer func() { + s.htlcSwitch.Stop() + }() + for { select { case msg := <-s.messages: @@ -79,8 +85,8 @@ func (s *mockServer) Start() error { for _, interceptor := range s.interceptorFuncs { skip, err := interceptor(msg) if err != nil { - s.errChan <- errors.Errorf("%v: error in the "+ - "interceptor: %v", s.name, err) + s.fail(errors.Errorf("%v: error in the "+ + "interceptor: %v", s.name, err)) return } shouldSkip = shouldSkip || skip @@ -91,7 +97,8 @@ func (s *mockServer) Start() error { } if err := s.readHandler(msg); err != nil { - s.errChan <- errors.Errorf("%v server error: %v", s.name, err) + s.fail(err) + return } case <-s.quit: return @@ -102,6 +109,16 @@ func (s *mockServer) Start() error { return nil } +func (s *mockServer) fail(err error) { + go func() { + s.Stop() + }() + + go func() { + s.errChan <- errors.Errorf("%v server error: %v", s.name, err) + }() +} + // mockHopIterator represents the test version of hop iterator which instead // of encrypting the path in onion blob just stores the path as a list of hops. type mockHopIterator struct { @@ -266,6 +283,7 @@ func (s *mockServer) SendMessage(message lnwire.Message) error { select { case s.messages <- message: case <-s.quit: + return errors.New("server is stopped") } return nil @@ -290,6 +308,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error { case *lnwire.FundingLocked: // Ignore return nil + case *lnwire.ChannelReestablish: + targetChan = msg.ChanID default: return errors.New("unknown message type") } @@ -323,24 +343,22 @@ func (s *mockServer) PubKey() [33]byte { func (s *mockServer) Disconnect(reason error) { fmt.Printf("server %v disconnected due to %v\n", s.name, reason) - s.Stop() - s.errChan <- errors.Errorf("server %v was disconnected: %v", s.name, - reason) + s.fail(errors.Errorf("server %v was disconnected: %v", s.name, reason)) } func (s *mockServer) WipeChannel(*lnwallet.LightningChannel) error { return nil } -func (s *mockServer) Stop() { +func (s *mockServer) Stop() error { if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { - return + return nil } - s.htlcSwitch.Stop() - close(s.quit) s.wg.Wait() + + return nil } func (s *mockServer) String() string { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 0e1ae2b49..e31af97d1 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -211,14 +211,16 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, case e := <-payment.err: err = e case <-s.quit: - return zeroPreimage, errors.New("switch is shutting down") + return zeroPreimage, errors.New("htlc switch have been stopped " + + "while waiting for payment result") } select { case p := <-payment.preimage: preimage = p case <-s.quit: - return zeroPreimage, errors.New("switch is shutting down") + return zeroPreimage, errors.New("htlc switch have been stopped " + + "while waiting for payment result") } return preimage, err @@ -316,7 +318,8 @@ func (s *Switch) forward(packet *htlcPacket) error { case err := <-command.err: return err case <-s.quit: - return errors.New("Htlc Switch was stopped") + return errors.New("unable to forward htlc packet htlc switch was " + + "stopped") } } @@ -803,7 +806,7 @@ func (s *Switch) htlcForwarder() { func (s *Switch) Start() error { if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { log.Warn("Htlc Switch already started") - return nil + return errors.New("htlc switch already started") } log.Infof("Starting HTLC Switch") @@ -819,10 +822,10 @@ func (s *Switch) Start() error { func (s *Switch) Stop() error { if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { log.Warn("Htlc Switch already stopped") - return nil + return errors.New("htlc switch already shutdown") } - log.Infof("HLTC Switch shutting down") + log.Infof("HTLC Switch shutting down") close(s.quit) s.wg.Wait() @@ -849,7 +852,7 @@ func (s *Switch) AddLink(link ChannelLink) error { case s.linkControl <- command: return <-command.err case <-s.quit: - return errors.New("Htlc Switch was stopped") + return errors.New("unable to add link htlc switch was stopped") } } @@ -903,7 +906,7 @@ func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) { case s.linkControl <- command: return <-command.done, <-command.err case <-s.quit: - return nil, errors.New("Htlc Switch was stopped") + return nil, errors.New("unable to get link htlc switch was stopped") } } @@ -947,7 +950,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error { case s.linkControl <- command: return <-command.err case <-s.quit: - return errors.New("Htlc Switch was stopped") + return errors.New("unable to remove link htlc switch was stopped") } } @@ -994,7 +997,7 @@ func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) { case s.linkControl <- command: return <-command.done, <-command.err case <-s.quit: - return nil, errors.New("Htlc Switch was stopped") + return nil, errors.New("unable to get links htlc switch was stopped") } } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index be1f56fd9..a22dc8a43 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -14,6 +14,8 @@ import ( "math/big" + "net" + "github.com/btcsuite/fastsha256" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" @@ -84,7 +86,9 @@ func generateRandomBytes(n int) ([]byte, error) { // TODO(roasbeef): need to factor out, similar func re-used in many parts of codebase func createTestChannel(alicePrivKey, bobPrivKey []byte, aliceAmount, bobAmount btcutil.Amount, - chanID lnwire.ShortChannelID) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) { + chanID lnwire.ShortChannelID) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), + func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel, + error), error) { aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), alicePrivKey) bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), bobPrivKey) @@ -98,7 +102,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, var hash [sha256.Size]byte randomSeed, err := generateRandomBytes(sha256.Size) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } copy(hash[:], randomSeed) @@ -133,7 +137,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, bobPreimageProducer := shachain.NewRevocationProducer(bobRoot) bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } bobCommitPoint := lnwallet.ComputeCommitmentPoint(bobFirstRevoke[:]) @@ -141,7 +145,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, alicePreimageProducer := shachain.NewRevocationProducer(aliceRoot) aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } aliceCommitPoint := lnwallet.ComputeCommitmentPoint(aliceFirstRevoke[:]) @@ -149,19 +153,19 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, bobAmount, &aliceCfg, &bobCfg, aliceCommitPoint, bobCommitPoint, fundingTxIn) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } alicePath, err := ioutil.TempDir("", "alicedb") dbAlice, err := channeldb.Open(alicePath) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } bobPath, err := ioutil.TempDir("", "bobdb") dbBob, err := channeldb.Open(bobPath) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } estimator := &lnwallet.StaticFeeEstimator{ @@ -171,6 +175,17 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, feePerKw := btcutil.Amount(estimator.EstimateFeePerWeight(1) * 1000) commitFee := (feePerKw * btcutil.Amount(724)) / 1000 + const broadcastHeight = 1 + bobAddr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + + aliceAddr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + aliceChannelState := &channeldb.OpenChannel{ LocalChanCfg: aliceCfg, RemoteChanCfg: bobCfg, @@ -191,6 +206,11 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, ShortChanID: chanID, Db: dbAlice, } + + if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil { + return nil, nil, nil, nil, err + } + bobChannelState := &channeldb.OpenChannel{ LocalChanCfg: bobCfg, RemoteChanCfg: aliceCfg, @@ -212,6 +232,10 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, Db: dbBob, } + if err := bobChannelState.SyncPending(aliceAddr, broadcastHeight); err != nil { + return nil, nil, nil, nil, err + } + cleanUpFunc := func() { os.RemoveAll(bobPath) os.RemoveAll(alicePath) @@ -223,33 +247,87 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, channelAlice, err := lnwallet.NewLightningChannel(aliceSigner, nil, estimator, aliceChannelState) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } channelBob, err := lnwallet.NewLightningChannel(bobSigner, nil, estimator, bobChannelState) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } // Now that the channel are open, simulate the start of a session by // having Alice and Bob extend their revocation windows to each other. aliceNextRevoke, err := channelAlice.NextRevocationKey() if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } if err := channelBob.InitNextRevocation(aliceNextRevoke); err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } bobNextRevoke, err := channelBob.NextRevocationKey() if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } if err := channelAlice.InitNextRevocation(bobNextRevoke); err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } - return channelAlice, channelBob, cleanUpFunc, nil + restore := func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel, + error) { + aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) + if err != nil { + return nil, nil, errors.Errorf("unable to fetch alice channel: "+ + "%v", err) + } + + var aliceStoredChannel *channeldb.OpenChannel + for _, channel := range aliceStoredChannels { + if channel.FundingOutpoint.String() == prevOut.String() { + aliceStoredChannel = channel + break + } + } + + if aliceStoredChannel == nil { + return nil, nil, errors.New("unable to find stored alice channel") + } + + newAliceChannel, err := lnwallet.NewLightningChannel(aliceSigner, + nil, estimator, aliceStoredChannel) + if err != nil { + return nil, nil, errors.Errorf("unable to create new channel: %v", + err) + } + + bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) + if err != nil { + return nil, nil, errors.Errorf("unable to fetch bob channel: "+ + "%v", err) + } + + var bobStoredChannel *channeldb.OpenChannel + for _, channel := range bobStoredChannels { + if channel.FundingOutpoint.String() == prevOut.String() { + bobStoredChannel = channel + break + } + } + + if bobStoredChannel == nil { + return nil, nil, errors.New("unable to find stored bob channel") + } + + newBobChannel, err := lnwallet.NewLightningChannel(bobSigner, nil, + estimator, bobStoredChannel) + if err != nil { + return nil, nil, errors.Errorf("unable to create new channel: %v", + err) + } + return newAliceChannel, newBobChannel, nil + } + + return channelAlice, channelBob, cleanUpFunc, restore, nil } // getChanID retrieves the channel point from nwire message. @@ -337,9 +415,6 @@ type threeHopNetwork struct { carolChannelLink *channelLink carolServer *mockServer - firstChannelCleanup func() - secondChannelCleanup func() - globalPolicy ForwardingPolicy } @@ -403,17 +478,17 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32, } type paymentResponse struct { - invoice *channeldb.Invoice - err chan error + rhash chainhash.Hash + err chan error } -func (r *paymentResponse) Wait(d time.Duration) (*channeldb.Invoice, error) { +func (r *paymentResponse) Wait(d time.Duration) (chainhash.Hash, error) { select { case err := <-r.err: close(r.err) - return r.invoice, err + return r.rhash, err case <-time.After(d): - return r.invoice, errors.New("htlc was no settled in time") + return r.rhash, errors.New("htlc was no settled in time") } } @@ -432,6 +507,8 @@ func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer Peer, paymentErr := make(chan error, 1) + var rhash chainhash.Hash + sender := sendingPeer.(*mockServer) receiver := receivingPeer.(*mockServer) @@ -441,42 +518,41 @@ func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer Peer, if err != nil { paymentErr <- err return &paymentResponse{ - invoice: nil, - err: paymentErr, + rhash: rhash, + err: paymentErr, } } // Generate payment: invoice and htlc. - invoice, htlc, err := generatePayment(invoiceAmt, htlcAmt, timelock, - blob) + invoice, htlc, err := generatePayment(invoiceAmt, htlcAmt, timelock, blob) if err != nil { paymentErr <- err return &paymentResponse{ - invoice: nil, - err: paymentErr, + rhash: rhash, + err: paymentErr, } } + rhash = fastsha256.Sum256(invoice.Terms.PaymentPreimage[:]) // Check who is last in the route and add invoice to server registry. if err := receiver.registry.AddInvoice(invoice); err != nil { paymentErr <- err return &paymentResponse{ - invoice: invoice, - err: paymentErr, + rhash: rhash, + err: paymentErr, } } // Send payment and expose err channel. - errChan := make(chan error) go func() { _, err := sender.htlcSwitch.SendHTLC(firstHopPub, htlc, newMockDeobfuscator()) - errChan <- err + paymentErr <- err }() return &paymentResponse{ - invoice: invoice, - err: errChan, + rhash: rhash, + err: paymentErr, } } @@ -516,9 +592,70 @@ func (n *threeHopNetwork) stop() { for i := 0; i < 3; i++ { <-done } +} - n.firstChannelCleanup() - n.secondChannelCleanup() +// clusterChannels... +type clusterChannels struct { + aliceToBob *lnwallet.LightningChannel + bobToAlice *lnwallet.LightningChannel + bobToCarol *lnwallet.LightningChannel + carolToBob *lnwallet.LightningChannel +} + +// createClusterChannels creates lightning channels which are needed for +// network cluster to be initialized. +func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( + *clusterChannels, func(), func() (*clusterChannels, error), error) { + + firstChanID := lnwire.NewShortChanIDFromInt(4) + secondChanID := lnwire.NewShortChanIDFromInt(5) + + // Create lightning channels between Alice<->Bob and Bob<->Carol + aliceChannel, firstBobChannel, cleanAliceBob, restoreAliceBob, err := createTestChannel( + alicePrivKey, bobPrivKey, aliceToBob, aliceToBob, firstChanID) + if err != nil { + return nil, nil, nil, errors.Errorf("unable to create "+ + "alice<->bob channel: %v", err) + } + + secondBobChannel, carolChannel, cleanBobCarol, restoreBobCarol, err := createTestChannel( + bobPrivKey, carolPrivKey, bobToCarol, bobToCarol, secondChanID) + if err != nil { + cleanAliceBob() + return nil, nil, nil, errors.Errorf("unable to create "+ + "bob<->carol channel: %v", err) + } + + cleanUp := func() { + cleanAliceBob() + cleanBobCarol() + } + + restoreFromDb := func() (*clusterChannels, error) { + a2b, b2a, err := restoreAliceBob() + if err != nil { + return nil, err + } + + b2c, c2b, err := restoreBobCarol() + if err != nil { + return nil, err + } + + return &clusterChannels{ + aliceToBob: a2b, + bobToAlice: b2a, + bobToCarol: b2c, + carolToBob: c2b, + }, nil + } + + return &clusterChannels{ + aliceToBob: aliceChannel, + bobToAlice: firstBobChannel, + bobToCarol: secondBobChannel, + carolToBob: carolChannel, + }, cleanUp, restoreFromDb, nil } // newThreeHopNetwork function creates the following topology and returns the @@ -534,10 +671,10 @@ func (n *threeHopNetwork) stop() { // alice first bob second bob carol // channel link channel link channel link channel link // -func newThreeHopNetwork(t *testing.T, aliceToBob, - bobToCarol btcutil.Amount, serverErr chan error, +func newThreeHopNetwork(t *testing.T, aliceChannel, firstBobChannel, + secondBobChannel, carolChannel *lnwallet.LightningChannel, + serverErr chan error, startingHeight uint32) *threeHopNetwork { - var err error // Create three peers/servers. aliceServer := newMockServer("alice", serverErr) @@ -548,22 +685,6 @@ func newThreeHopNetwork(t *testing.T, aliceToBob, // route which htlc should follow. decoder := &mockIteratorDecoder{} - firstChanID := lnwire.NewShortChanIDFromInt(4) - secondChanID := lnwire.NewShortChanIDFromInt(5) - - // Create lightning channels between Alice<->Bob and Bob<->Carol - aliceChannel, firstBobChannel, fCleanUp, err := createTestChannel( - alicePrivKey, bobPrivKey, aliceToBob, aliceToBob, firstChanID) - if err != nil { - t.Fatalf("unable to create alice<->bob channel: %v", err) - } - - secondBobChannel, carolChannel, sCleanUp, err := createTestChannel( - bobPrivKey, carolPrivKey, bobToCarol, bobToCarol, secondChanID) - if err != nil { - t.Fatalf("unable to create bob<->carol channel: %v", err) - } - globalEpoch := &chainntnfs.BlockEpochEvent{ Epochs: make(chan *chainntnfs.BlockEpoch), Cancel: func() { @@ -588,6 +709,7 @@ func newThreeHopNetwork(t *testing.T, aliceToBob, GetLastChannelUpdate: mockGetChanUpdateMessage, Registry: aliceServer.registry, BlockEpochs: globalEpoch, + SyncStates: true, }, aliceChannel, startingHeight, @@ -609,6 +731,7 @@ func newThreeHopNetwork(t *testing.T, aliceToBob, GetLastChannelUpdate: mockGetChanUpdateMessage, Registry: bobServer.registry, BlockEpochs: globalEpoch, + SyncStates: true, }, firstBobChannel, startingHeight, @@ -630,6 +753,7 @@ func newThreeHopNetwork(t *testing.T, aliceToBob, GetLastChannelUpdate: mockGetChanUpdateMessage, Registry: bobServer.registry, BlockEpochs: globalEpoch, + SyncStates: true, }, secondBobChannel, startingHeight, @@ -651,6 +775,7 @@ func newThreeHopNetwork(t *testing.T, aliceToBob, GetLastChannelUpdate: mockGetChanUpdateMessage, Registry: carolServer.registry, BlockEpochs: globalEpoch, + SyncStates: true, }, carolChannel, startingHeight, @@ -668,9 +793,6 @@ func newThreeHopNetwork(t *testing.T, aliceToBob, carolChannelLink: carolChannelLink.(*channelLink), carolServer: carolServer, - firstChannelCleanup: fCleanUp, - secondChannelCleanup: sCleanUp, - globalPolicy: globalPolicy, } } diff --git a/lnd_test.go b/lnd_test.go index 81aa54bdd..e5c8af866 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -3923,6 +3923,170 @@ func testBidirectionalAsyncPayments(net *networkHarness, t *harnessTest) { closeChannelAndAssert(ctxt, t, net, net.Alice, chanPoint, false) } +// testChannelReestablishment... +func testChannelReestablishment(net *networkHarness, t *harnessTest) { + ctxb := context.Background() + + // As we'll be querying the channels state frequently we'll + // create a closure helper function for the purpose. + getChanInfo := func(node *lightningNode) (*lnrpc.ActiveChannel, error) { + req := &lnrpc.ListChannelsRequest{} + channelInfo, err := node.ListChannels(ctxb, req) + if err != nil { + return nil, err + } + if len(channelInfo.Channels) != 1 { + t.Fatalf("node should only have a single channel, "+ + "instead he has %v", + len(channelInfo.Channels)) + } + + return channelInfo.Channels[0], nil + } + + const ( + timeout = time.Duration(time.Second * 5) + paymentAmt = 100 + ) + + // First establish a channel with a capacity equals to the overall + // amount of payments, between Alice and Bob, at the end of the test + // Alice should send all money from her side to Bob. + ctxt, _ := context.WithTimeout(ctxb, timeout) + chanPoint := openChannelAndAssert(ctxt, t, net, net.Alice, net.Bob, + paymentAmt*500, 0) + + info, err := getChanInfo(net.Alice) + if err != nil { + t.Fatalf("unable to get alice channel info: %v", err) + } + + // Calculate the number of invoices. + numInvoices := int(info.LocalBalance / paymentAmt) + + // Initialize seed random in order to generate invoices. + rand.Seed(time.Now().UnixNano()) + + // With the channel open, we'll create a invoices for Bob that + // Alice will pay to in order to advance the state of the channel. + bobPaymentHashes := make([][]byte, numInvoices) + for i := 0; i < numInvoices; i++ { + preimage := make([]byte, 32) + _, err := rand.Read(preimage) + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } + + invoice := &lnrpc.Invoice{ + Memo: "testing", + RPreimage: preimage, + Value: paymentAmt, + } + resp, err := net.Bob.AddInvoice(ctxb, invoice) + if err != nil { + t.Fatalf("unable to add invoice: %v", err) + } + + bobPaymentHashes[i] = resp.RHash + } + + // Wait for Alice to receive the channel edge from the funding manager. + ctxt, _ = context.WithTimeout(ctxb, timeout) + err = net.Alice.WaitForNetworkChannelOpen(ctxt, chanPoint) + if err != nil { + t.Fatalf("alice didn't see the alice->bob channel before "+ + "timeout: %v", err) + } + + // Open up a payment stream to Alice that we'll use to send payment to + // Bob. We also create a small helper function to send payments to Bob, + // consuming the payment hashes we generated above. + ctxt, _ = context.WithTimeout(ctxb, timeout) + alicePayStream, err := net.Alice.SendPayment(ctxt) + if err != nil { + t.Fatalf("unable to create payment stream for alice: %v", err) + } + + // Send payments from Alice to Bob using of Bob's payment hashes + // generated above. + for i := 0; i < numInvoices-1; i++ { + sendReq := &lnrpc.SendRequest{ + PaymentHash: bobPaymentHashes[i], + Dest: net.Bob.PubKey[:], + Amt: paymentAmt, + } + + if err := alicePayStream.Send(sendReq); err != nil { + t.Fatalf("unable to send payment: "+ + "stream has been closed: %v", err) + } + } + + // We should receive one insufficient capacity error, because we are + // sending on one invoice bigger. + + for i := 0; i < numInvoices/2; i++ { + if resp, err := alicePayStream.Recv(); err != nil { + t.Fatalf("payment stream has been closed: %v", err) + } else if resp.PaymentError != "" { + t.Fatalf("unable to finish the payment: %v", err) + } + } + + errChan := make(chan error) + go func() { + errChan <- net.Bob.Restart(net.lndErrorChan, nil) + }() + go func() { + errChan <- net.Alice.Restart(net.lndErrorChan, nil) + }() + + for i := 0; i < 2; i++ { + select { + case err := <-errChan: + if err != nil { + t.Fatalf("unable to restart node: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("unable to restart node: timeout") + } + } + + // Open up a payment stream to Alice that we'll use to send payment to + // Bob. We also create a small helper function to send payments to Bob, + // consuming the payment hashes we generated above. + ctxt, _ = context.WithTimeout(ctxb, timeout) + alicePayStream, err = net.Alice.SendPayment(ctxt) + if err != nil { + t.Fatalf("unable to create payment stream for alice: %v", err) + } + + sendReq := &lnrpc.SendRequest{ + PaymentHash: bobPaymentHashes[numInvoices-1], + Dest: net.Bob.PubKey[:], + Amt: paymentAmt, + } + + if err := alicePayStream.Send(sendReq); err != nil { + t.Fatalf("unable to send payment: "+ + "stream has been closed: %v", err) + } + + if resp, err := alicePayStream.Recv(); err != nil { + t.Fatalf("payment stream has been closed: %v", err) + } else if resp.PaymentError != "" { + t.Fatalf("unable to send the payment: %v", resp.PaymentError) + } + + time.Sleep(time.Second) + + // Finally, immediately close the channel. This function will also + // block until the channel is closed and will additionally assert the + // relevant channel closing post conditions. + ctxt, _ = context.WithTimeout(ctxb, timeout) + closeChannelAndAssert(ctxt, t, net, net.Alice, chanPoint, false) +} + type testCase struct { name string test func(net *networkHarness, t *harnessTest) @@ -4003,6 +4167,12 @@ var testsCases = []*testCase{ test: testBidirectionalAsyncPayments, }, { + name: "channel reestablishment", + test: testChannelReestablishment, + }, + { + // TODO(roasbeef): test always needs to be last as Bob's state + // is borked since we trick him into attempting to cheat Alice? name: "revoked uncooperative close retribution", test: testRevokedCloseRetribution, }, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index b9299544e..3f846f97d 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -4,6 +4,7 @@ import ( "bytes" "container/list" "crypto/sha256" + "errors" "fmt" "runtime" "sort" @@ -2428,6 +2429,132 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig return sig, htlcSigs, nil } +// ReceiveReestablish is used to handle the remote channel reestablish message +// and generate the set of updates which are have to be sent to remote side +// to synchronize the states of the channels. +func (lc *LightningChannel) ReceiveReestablish(msg *lnwire.ChannelReestablish) ( + []lnwire.Message, error) { + + lc.Lock() + defer lc.Unlock() + var updates []lnwire.Message + + // As far we store on last commitment transaction we should rely on the + // height of the commitment transaction in order to calculate the length. + numberRemoteCommitments := lc.remoteCommitChain.tip().height + 1 + + // Number of the revocations might be calculated as the height of the + // commitment transactions which will be revoked next minus one. And plus + // one because height starts from zero. + numberRemoteRevocations := lc.localCommitChain.tail().height - 1 + 1 + + revocationsnumberDiff := msg.NextRemoteRevocationNumber - numberRemoteRevocations + if revocationsnumberDiff == 0 { + // If remote side expects as receive revocation which we already + // consider as last, than it means that they aren't received our + // last revocation message. + revocationMsg, err := lc.generateRevocation(lc.currentHeight - 1) + if err != nil { + return nil, err + } + updates = append(updates, revocationMsg) + } else if revocationsnumberDiff < 0 { + // Remote node claims that it received the revoke_and_ack message + // which we did not send. + return nil, errors.New("remote side claims that it haven't received " + + "acked revoke and ack message") + } + + commitmentChainDiff := msg.NextLocalCommitmentNumber - numberRemoteCommitments + if commitmentChainDiff == 0 { + // If remote side expects as receive commitment which we already + // consider as last, than it means that they aren't received our + // last commit sig message. + commitment := lc.remoteCommitChain.tip() + chanID := lnwire.NewChanIDFromOutPoint(&lc.channelState.FundingOutpoint) + for _, htlc := range commitment.outgoingHTLCs { + // If htlc is included in the local commitment chain (have been + // included by remote side) or htlc is included in remote chain, but + // not in the last commimemnt transaction than we should skip it, + // because we need resend only updates which haven't been received + // by remotes side. + if htlc.addCommitHeightLocal != 0 || + (htlc.addCommitHeightLocal != 0 && + htlc.addCommitHeightLocal <= commitment.height) { + continue + } + + switch htlc.EntryType { + case Add: + var onionBlob [lnwire.OnionPacketSize]byte + copy(onionBlob[:], htlc.Payload) + updates = append(updates, &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: htlc.Index, + Expiry: htlc.Timeout, + Amount: htlc.Amount, + PaymentHash: htlc.RHash, + OnionBlob: onionBlob, + }) + case Fail: + updates = append(updates, &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: htlc.Index, + Reason: lnwire.OpaqueReason([]byte{}), + }) + case Settle: + updates = append(updates, &lnwire.UpdateFufillHTLC{ + ChanID: chanID, + ID: htlc.Index, + PaymentPreimage: htlc.RPreimage, + }) + } + } + + // Generate last sent commit sig message by signing the transaction and + // creating the signature. + lc.signDesc.SigHashes = txscript.NewTxSigHashes(commitment.txn) + sig, err := lc.signer.SignOutputRaw(commitment.txn, lc.signDesc) + if err != nil { + return nil, err + } + + commitSig, err := btcec.ParseSignature(sig, btcec.S256()) + if err != nil { + return nil, err + } + updates = append(updates, &lnwire.CommitSig{ + ChanID: chanID, + CommitSig: commitSig, + }) + + } else if commitmentChainDiff < 0 { + // Remote node claims that it received the commit sig message which we + // did not send. + return nil, errors.New("remote side claims that it haven't received " + + "acked commit sig message") + } + + return updates, nil +} + +// LastCounters returns the historical length of the local commimemnt +// transaction chain and the historical number of the the revocked commiment +// transactions, whicha are needed in order to generate the channel +// reestablish message. +func (lc *LightningChannel) LastCounters() (uint64, uint64) { + // As far we store on last commitment transaction we should rely on the + // height of the commitment transaction in order to calculate the length. + numberLocalCommitments := lc.localCommitChain.tip().height + 1 + + // Number of the revocations might be calculated as the height of the + // commitment transactions which will be revoked next minus one. And plus + // one because height starts from zero. + numberRemoteRevocations := lc.remoteCommitChain.tail().height - 1 + 1 + + return numberLocalCommitments, numberRemoteRevocations +} + // validateCommitmentSanity is used to validate that on current state the commitment // transaction is valid in terms of propagating it over Bitcoin network, and // also that all outputs are meet Bitcoin spec requirements and they are @@ -2782,43 +2909,13 @@ func (lc *LightningChannel) RevokeCurrentCommitment() (*lnwire.RevokeAndAck, err lc.Lock() defer lc.Unlock() - // Now that we've accept a new state transition, we send the remote - // party the revocation for our current commitment state. - revocationMsg := &lnwire.RevokeAndAck{} - commitSecret, err := lc.channelState.RevocationProducer.AtIndex( - lc.currentHeight, - ) + revocationMsg, err := lc.generateRevocation(lc.currentHeight) if err != nil { return nil, err } - copy(revocationMsg.Revocation[:], commitSecret[:]) - // Along with this revocation, we'll also send the _next_ commitment - // point that the remote party should use to create our next commitment - // transaction. We use a +2 here as we already gave them a look ahead - // of size one after the FundingLocked message was sent: - // - // 0: current revocation, 1: their "next" revocation, 2: this revocation - // - // We're revoking the current revocation. Once they receive this - // message they'll set the "current" revocation for us to their stored - // "next" revocation, and this revocation will become their new "next" - // revocation. - // - // Put simply in the window slides to the left by one. - nextCommitSecret, err := lc.channelState.RevocationProducer.AtIndex( - lc.currentHeight + 2, - ) - if err != nil { - return nil, err - } - revocationMsg.NextRevocationKey = ComputeCommitmentPoint( - nextCommitSecret[:], - ) - - walletLog.Tracef("ChannelPoint(%v): revoking height=%v, now at height=%v", - lc.channelState.FundingOutpoint, lc.localCommitChain.tail().height, - lc.currentHeight+1) + walletLog.Tracef("ChannelPoint(%v): revoking height=%v, now at height=%v", lc.channelState.FundingOutpoint, + lc.localCommitChain.tail().height, lc.currentHeight+1) // Advance our tail, as we've revoked our previous state. lc.localCommitChain.advanceTail() @@ -3860,6 +3957,47 @@ func (lc *LightningChannel) ReceiveUpdateFee(feePerKw btcutil.Amount) error { return nil } +// generateRevocation generate lnwire revocation message by the given height +// and revocation edge. +func (lc *LightningChannel) generateRevocation(height uint64) (*lnwire.RevokeAndAck, + error) { + + // Now that we've accept a new state transition, we send the remote + // party the revocation for our current commitment state. + revocationMsg := &lnwire.RevokeAndAck{} + commitSecret, err := lc.channelState.RevocationProducer.AtIndex(height) + if err != nil { + return nil, err + } + copy(revocationMsg.Revocation[:], commitSecret[:]) + + // Along with this revocation, we'll also send the _next_ commitment + // point that the remote party should use to create our next commitment + // transaction. We use a +2 here as we already gave them a look ahead + // of size one after the FundingLocked message was sent: + // + // 0: current revocation, 1: their "next" revocation, 2: this revocation + // + // We're revoking the current revocation. Once they receive this + // message they'll set the "current" revocation for us to their stored + // "next" revocation, and this revocation will become their new "next" + // revocation. + // + // Put simply in the window slides to the left by one. + nextCommitSecret, err := lc.channelState.RevocationProducer.AtIndex( + height + 2, + ) + if err != nil { + return nil, err + } + + revocationMsg.NextRevocationKey = ComputeCommitmentPoint(nextCommitSecret[:]) + revocationMsg.ChanID = lnwire.NewChanIDFromOutPoint( + &lc.channelState.FundingOutpoint) + + return revocationMsg, nil +} + // CreateCommitTx creates a commitment transaction, spending from specified // funding output. The commitment transaction contains two outputs: one paying // to the "owner" of the commitment transaction which can be spent after a diff --git a/peer.go b/peer.go index 2057b3cc8..80e16ef68 100644 --- a/peer.go +++ b/peer.go @@ -399,6 +399,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error { Switch: p.server.htlcSwitch, FwrdingPolicy: *forwardingPolicy, BlockEpochs: blockEpoch, + SyncStates: true, } link := htlcswitch.NewChannelLink(linkCfg, lnChan, uint32(currentHeight)) @@ -745,6 +746,9 @@ out: case *lnwire.UpdateFee: isChanUpdate = true targetChan = msg.ChanID + case *lnwire.ChannelReestablish: + isChanUpdate = true + targetChan = msg.ChanID case *lnwire.ChannelUpdate, *lnwire.ChannelAnnouncement, @@ -1261,6 +1265,7 @@ out: Switch: p.server.htlcSwitch, FwrdingPolicy: p.server.cc.routingPolicy, BlockEpochs: blockEpoch, + SyncStates: false, } link := htlcswitch.NewChannelLink(linkConfig, newChan, uint32(currentHeight))