From 72e8b900dbf48f135d3afbb30733202dfe9fc89a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 7 Nov 2023 11:42:42 +0200 Subject: [PATCH] multi: let some netann funcs use lnwire.ChannelUpdate ...interface instead of ChannelUpdate1. --- discovery/chan_series.go | 6 +- discovery/gossiper.go | 50 +++++++++---- discovery/gossiper_test.go | 6 +- discovery/syncer.go | 19 +++-- discovery/syncer_test.go | 10 +-- netann/chan_status_manager.go | 6 +- netann/chan_status_manager_test.go | 43 +++++------ netann/channel_update.go | 75 ++++++++++++++++--- peer/test_utils.go | 2 +- routing/router_test.go | 115 +++++++++++++++++++---------- server.go | 2 +- 11 files changed, 227 insertions(+), 107 deletions(-) diff --git a/discovery/chan_series.go b/discovery/chan_series.go index d91396a54..93a37f070 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -61,7 +61,7 @@ type ChannelGraphTimeSeries interface { // specified short channel ID. If no channel updates are known for the // channel, then an empty slice will be returned. FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, + shortChanID lnwire.ShortChannelID) ([]lnwire.ChannelUpdate, error) } @@ -332,7 +332,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { + shortChanID lnwire.ShortChannelID) ([]lnwire.ChannelUpdate, error) { chanInfo, e1, e2, err := c.graph.FetchChannelEdgesByID( shortChanID.ToUint64(), @@ -341,7 +341,7 @@ func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, return nil, err } - chanUpdates := make([]*lnwire.ChannelUpdate1, 0, 2) + chanUpdates := make([]lnwire.ChannelUpdate, 0, 2) if e1 != nil { chanUpdate, err := netann.ChannelUpdateFromEdge(chanInfo, e1) if err != nil { diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 2672e8227..785f66fb3 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -1826,7 +1826,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( var defaultAlias lnwire.ShortChannelID foundAlias, _ := d.cfg.GetAlias(chanID) if foundAlias != defaultAlias { - chanUpdate.ShortChannelID = foundAlias + chanUpdate.SetSCID(foundAlias) err := d.cfg.SignAliasUpdate(chanUpdate) if err != nil { @@ -1846,7 +1846,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( log.Errorf("Unable to reliably send %v for "+ "channel=%v to peer=%x: %v", chanUpdate.MsgType(), - chanUpdate.ShortChannelID, + chanUpdate.SCID(), remotePubKey, err) } continue @@ -2244,23 +2244,17 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, - edgePolicy models.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate1, error) { + edge models.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, + lnwire.ChannelUpdate, error) { // Parse the unsigned edge into a channel update. chanUpdate, err := netann.UnsignedChannelUpdateFromEdge( - edgeInfo.GetChainHash(), edgePolicy, + edgeInfo.GetChainHash(), edge, ) if err != nil { return nil, nil, err } - edge, ok := edgePolicy.(*models.ChannelEdgePolicy1) - if !ok { - return nil, nil, fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got: %T", edgePolicy) - } - // We'll generate a new signature over a digest of the channel // announcement itself and update the timestamp to ensure it propagate. err = netann.SignChannelUpdate( @@ -2273,8 +2267,25 @@ func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, // Next, we'll set the new signature in place, and update the reference // in the backing slice. - edge.LastUpdate = time.Unix(int64(chanUpdate.Timestamp), 0) - edge.SigBytes = chanUpdate.Signature.ToSignatureBytes() + switch e := edge.(type) { + case *models.ChannelEdgePolicy1: + chanUpd, ok := chanUpdate.(*lnwire.ChannelUpdate1) + if !ok { + return nil, nil, fmt.Errorf("wanted chan update 1") + } + + e.LastUpdate = time.Unix(int64(chanUpd.Timestamp), 0) + e.SigBytes = chanUpd.Signature.ToSignatureBytes() + + case *models.ChannelEdgePolicy2: + chanUpd, ok := chanUpdate.(*lnwire.ChannelUpdate2) + if !ok { + return nil, nil, fmt.Errorf("wanted chan update 2") + } + + e.BlockHeight = chanUpd.BlockHeight + e.Signature = chanUpd.Signature + } // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. @@ -2302,6 +2313,10 @@ func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, if err != nil { return nil, nil, err } + + case *models.ChannelEdgeInfo2: + chanAnn = chanAnn2FromEdgeInfo2(info) + default: return nil, nil, fmt.Errorf("unhandled "+ "implementation of models.ChannelEdgeInfo: "+ @@ -2356,6 +2371,15 @@ func chanAnn1FromEdgeInfo1(info *models.ChannelEdgeInfo1) ( return chanAnn, nil } +func chanAnn2FromEdgeInfo2( + info *models.ChannelEdgeInfo2) *lnwire.ChannelAnnouncement2 { + + chanAnn := info.ChannelAnnouncement2 + chanAnn.Signature = info.Signature + + return &chanAnn +} + // SyncManager returns the gossiper's SyncManager instance. func (d *AuthenticatedGossiper) SyncManager() *SyncManager { return d.syncMgr diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 5f0497eb6..84a77263d 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -266,9 +266,11 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( }, nil, nil, channeldb.ErrZombieEdge } + chanInfoCP := chanInfo.Copy() + edges := r.edges[chanID.ToUint64()] if len(edges) == 0 { - return chanInfo, nil, nil, nil + return chanInfoCP, nil, nil, nil } var edge1 models.ChannelEdgePolicy @@ -281,7 +283,7 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( edge2 = edges[1] } - return chanInfo, edge1, edge2, nil + return chanInfoCP, edge1, edge2, nil } func (r *mockGraphSource) FetchLightningNode( diff --git a/discovery/syncer.go b/discovery/syncer.go index 886aa4be0..a3a2945fa 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1413,16 +1413,16 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // to quickly check if we should forward a chan ann, based on the known // channel updates for a channel. chanUpdateIndex := make( - map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate1, + map[lnwire.ShortChannelID][]lnwire.ChannelUpdate, ) for _, msg := range msgs { - chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate1) + chanUpdate, ok := msg.msg.(lnwire.ChannelUpdate) if !ok { continue } - chanUpdateIndex[chanUpdate.ShortChannelID] = append( - chanUpdateIndex[chanUpdate.ShortChannelID], chanUpdate, + chanUpdateIndex[chanUpdate.SCID()] = append( + chanUpdateIndex[chanUpdate.SCID()], chanUpdate, ) } @@ -1475,7 +1475,16 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { } for _, chanUpdate := range chanUpdates { - if passesFilter(chanUpdate.Timestamp) { + update, ok := chanUpdate.(*lnwire.ChannelUpdate1) + if !ok { + log.Errorf("expected "+ + "*lnwire.ChannelUpdate1, "+ + "got: %T", update) + + continue + } + + if passesFilter(update.Timestamp) { msgsToSend = append(msgsToSend, msg) break } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 0ee635a0f..f17658057 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -52,7 +52,7 @@ type mockChannelGraphTimeSeries struct { annResp chan []lnwire.Message updateReq chan lnwire.ShortChannelID - updateResp chan []*lnwire.ChannelUpdate1 + updateResp chan []lnwire.ChannelUpdate } func newMockChannelGraphTimeSeries( @@ -74,7 +74,7 @@ func newMockChannelGraphTimeSeries( annResp: make(chan []lnwire.Message, 1), updateReq: make(chan lnwire.ShortChannelID, 1), - updateResp: make(chan []*lnwire.ChannelUpdate1, 1), + updateResp: make(chan []lnwire.ChannelUpdate, 1), } } @@ -149,7 +149,7 @@ func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, return <-m.annResp, nil } func (m *mockChannelGraphTimeSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { + shortChanID lnwire.ShortChannelID) ([]lnwire.ChannelUpdate, error) { m.updateReq <- shortChanID @@ -369,8 +369,8 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { } // If so, then we'll send back the missing update. - chanSeries.updateResp <- []*lnwire.ChannelUpdate1{ - { + chanSeries.updateResp <- []lnwire.ChannelUpdate{ + &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(25), Timestamp: unixStamp(5), }, diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index 9a7b30e3a..a75db58dd 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -63,7 +63,7 @@ type ChanStatusConfig struct { // ApplyChannelUpdate processes new ChannelUpdates signed by our node by // updating our local routing table and broadcasting the update to our // peers. - ApplyChannelUpdate func(*lnwire.ChannelUpdate1, *wire.OutPoint, + ApplyChannelUpdate func(lnwire.ChannelUpdate, *wire.OutPoint, bool) error // DB stores the set of channels that are to be monitored. @@ -658,7 +658,7 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, // in case our ChannelEdgePolicy is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( - *lnwire.ChannelUpdate1, bool, error) { + lnwire.ChannelUpdate, bool, error) { // Get the edge info and policies for this channel from the graph. info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint(&op) @@ -689,7 +689,7 @@ func (m *ChanStatusManager) loadInitialChanState( // Determine the channel's starting status by inspecting the disable bit // on last announcement we sent out. var initialStatus ChanStatus - if lastUpdate.ChannelFlags&lnwire.ChanUpdateDisabled == 0 { + if !lastUpdate.IsDisabled() { initialStatus = ChanStatusEnabled } else { initialStatus = ChanStatusDisabled diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index f0025d7eb..e5e0e48b9 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -136,7 +136,7 @@ type mockGraph struct { chanPols2 map[wire.OutPoint]*models.ChannelEdgePolicy1 sidToCid map[lnwire.ShortChannelID]wire.OutPoint - updates chan *lnwire.ChannelUpdate1 + updates chan lnwire.ChannelUpdate } func newMockGraph(t *testing.T, numChannels int, @@ -148,7 +148,7 @@ func newMockGraph(t *testing.T, numChannels int, chanPols1: make(map[wire.OutPoint]*models.ChannelEdgePolicy1), chanPols2: make(map[wire.OutPoint]*models.ChannelEdgePolicy1), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), - updates: make(chan *lnwire.ChannelUpdate1, 2*numChannels), + updates: make(chan lnwire.ChannelUpdate, 2*numChannels), } for i := 0; i < numChannels; i++ { @@ -187,46 +187,47 @@ func (g *mockGraph) FetchChannelEdgesByOutpoint( return info, pol1, pol2, nil } -func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, +func (g *mockGraph) ApplyChannelUpdate(update lnwire.ChannelUpdate, op *wire.OutPoint, private bool) error { g.mu.Lock() defer g.mu.Unlock() - outpoint, ok := g.sidToCid[update.ShortChannelID] + outpoint, ok := g.sidToCid[update.SCID()] if !ok { return fmt.Errorf("unknown short channel id: %v", - update.ShortChannelID) + update.SCID()) } pol1 := g.chanPols1[outpoint] pol2 := g.chanPols2[outpoint] - // Determine which policy we should update by making the flags on the // policies and updates, and seeing which match up. var update1 bool + switch { - case update.ChannelFlags&lnwire.ChanUpdateDirection == - pol1.ChannelFlags&lnwire.ChanUpdateDirection: + case update.IsNode1() == pol1.IsNode1(): update1 = true - case update.ChannelFlags&lnwire.ChanUpdateDirection == - pol2.ChannelFlags&lnwire.ChanUpdateDirection: + case update.IsNode1() == pol2.IsNode1(): update1 = false default: return fmt.Errorf("unable to find policy to update") } - timestamp := time.Unix(int64(update.Timestamp), 0) + upd, ok := update.(*lnwire.ChannelUpdate1) + if !ok { + return fmt.Errorf("expected channel update 1") + } + timestamp := time.Unix(int64(upd.Timestamp), 0) policy := &models.ChannelEdgePolicy1{ - ChannelID: update.ShortChannelID.ToUint64(), - ChannelFlags: update.ChannelFlags, + ChannelID: upd.ShortChannelID.ToUint64(), + ChannelFlags: upd.ChannelFlags, LastUpdate: timestamp, SigBytes: testSigBytes, } - if update1 { g.chanPols1[outpoint] = policy } else { @@ -517,23 +518,23 @@ func (h *testHarness) assertUpdates(channels []*channeldb.OpenChannel, for { select { case upd := <-h.graph.updates: + scid := upd.SCID() + // Assert that the received short channel id is one that // we expect. If no updates were expected, this will // always fail on the first update received. - if _, ok := expSids[upd.ShortChannelID]; !ok { + if _, ok := expSids[scid]; !ok { h.t.Fatalf("received update for unexpected "+ - "short chan id: %v", upd.ShortChannelID) + "short chan id: %v", scid) } // Assert that the disabled bit is set properly. - enabled := upd.ChannelFlags&lnwire.ChanUpdateDisabled != - lnwire.ChanUpdateDisabled - if expEnabled != enabled { + if expEnabled != !upd.IsDisabled() { h.t.Fatalf("expected enabled: %v, actual: %v", - expEnabled, enabled) + expEnabled, !upd.IsDisabled()) } - recvdSids[upd.ShortChannelID] = struct{}{} + recvdSids[scid] = struct{}{} case <-timeout: // Time is up, assert that the correct number of unique diff --git a/netann/channel_update.go b/netann/channel_update.go index a947aca43..08d83dfe2 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -128,9 +128,9 @@ func ExtractChannelUpdate(ownerPubKey []byte, *lnwire.ChannelUpdate1, error) { // Helper function to extract the owner of the given policy. - owner := func(edge *models.ChannelEdgePolicy1) []byte { + owner := func(edge models.ChannelEdgePolicy) []byte { var pubKey *btcec.PublicKey - if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + if edge.IsNode1() { pubKey, _ = info.NodeKey1() } else { pubKey, _ = info.NodeKey2() @@ -146,14 +146,20 @@ func ExtractChannelUpdate(ownerPubKey []byte, // Extract the channel update from the policy we own, if any. for _, edge := range policies { - e, ok := edge.(*models.ChannelEdgePolicy1) - if !ok { - return nil, fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got: %T", edge) - } + if edge != nil && bytes.Equal(ownerPubKey, owner(edge)) { + update, err := ChannelUpdateFromEdge(info, edge) + if err != nil { + return nil, err + } - if edge != nil && bytes.Equal(ownerPubKey, owner(e)) { - return ChannelUpdateFromEdge(info, edge) + chanUpd1, ok := update.(*lnwire.ChannelUpdate1) + if !ok { + return nil, fmt.Errorf("expected "+ + "*lnwire.ChannelUpdate1, got: %T", + chanUpd1) + } + + return chanUpd1, nil } } @@ -163,12 +169,15 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the // given edge info and policy. func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, - policy models.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) { + policy models.ChannelEdgePolicy) (lnwire.ChannelUpdate, error) { switch p := policy.(type) { case *models.ChannelEdgePolicy1: return unsignedChanPolicy1ToUpdate(chainHash, p), nil + case *models.ChannelEdgePolicy2: + return unsignedChanPolicy2ToUpdate(chainHash, p), nil + default: return nil, fmt.Errorf("unhandled implementation of the "+ "models.ChanelEdgePolicy interface: %T", policy) @@ -193,10 +202,36 @@ func unsignedChanPolicy1ToUpdate(chainHash chainhash.Hash, } } +func unsignedChanPolicy2ToUpdate(chainHash chainhash.Hash, + policy *models.ChannelEdgePolicy2) *lnwire.ChannelUpdate2 { + + update := &lnwire.ChannelUpdate2{ + ShortChannelID: policy.ShortChannelID, + BlockHeight: policy.BlockHeight, + DisabledFlags: policy.DisabledFlags, + SecondPeer: policy.SecondPeer, + CLTVExpiryDelta: policy.CLTVExpiryDelta, + HTLCMinimumMsat: policy.HTLCMinimumMsat, + HTLCMaximumMsat: policy.HTLCMaximumMsat, + FeeBaseMsat: policy.FeeBaseMsat, + FeeProportionalMillionths: policy.FeeProportionalMillionths, + ExtraOpaqueData: policy.ExtraOpaqueData, + } + update.ChainHash.Val = chainHash + + return update +} + // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given // edge info and policy. func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, - policy models.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) { + policy models.ChannelEdgePolicy) (lnwire.ChannelUpdate, error) { + + return signedChannelUpdateFromEdge(info.GetChainHash(), policy) +} + +func signedChannelUpdateFromEdge(chainHash chainhash.Hash, + policy models.ChannelEdgePolicy) (lnwire.ChannelUpdate, error) { switch p := policy.(type) { case *models.ChannelEdgePolicy1: @@ -210,7 +245,23 @@ func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, return nil, err } - update := unsignedChanPolicy1ToUpdate(info.GetChainHash(), p) + update := unsignedChanPolicy1ToUpdate(chainHash, p) + update.Signature = s + + return update, nil + + case *models.ChannelEdgePolicy2: + sig, err := p.Signature.ToSignature() + if err != nil { + return nil, err + } + + s, err := lnwire.NewSigFromSignature(sig) + if err != nil { + return nil, err + } + + update := unsignedChanPolicy2ToUpdate(chainHash, p) update.Signature = s return update, nil diff --git a/peer/test_utils.go b/peer/test_utils.go index 0575acca5..09375c652 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -611,7 +611,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { IsChannelActive: func(lnwire.ChannelID) bool { return true }, - ApplyChannelUpdate: func(*lnwire.ChannelUpdate1, + ApplyChannelUpdate: func(lnwire.ChannelUpdate, *wire.OutPoint, bool) error { return nil diff --git a/routing/router_test.go b/routing/router_test.go index cab0bff27..8a0b88c7f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -15,8 +15,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" @@ -28,6 +30,7 @@ import ( "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -146,7 +149,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, MissionControl: mc, } - graphBuilder := newMockGraphBuilder(graphInstance.graph) + graphBuilder := newMockGraphBuilder(t, graphInstance.graph) router, err := New(Config{ SelfNode: sourceNode.PubKeyBytes, @@ -221,16 +224,50 @@ func createTestCtxFromFile(t *testing.T, // Add valid signature to channel update simulated as error received from the // network. func signErrChanUpdate(t *testing.T, key *btcec.PrivateKey, - errChanUpdate *lnwire.ChannelUpdate1) { + errChanUpdate lnwire.ChannelUpdate) { - chanUpdateMsg, err := errChanUpdate.DataToSign() - require.NoError(t, err, "failed to retrieve data to sign") + signer := &mockSigner{key: key} + err := netann.SignChannelUpdate( + signer, keychain.KeyLocator{}, errChanUpdate, + ) + require.NoError(t, err) +} - digest := chainhash.DoubleHashB(chanUpdateMsg) - sig := ecdsa.Sign(key, digest) +type mockSigner struct { + key *btcec.PrivateKey + keychain.MessageSignerRing +} - errChanUpdate.Signature, err = lnwire.NewSigFromSignature(sig) - require.NoError(t, err, "failed to create new signature") +func (s *mockSigner) SignMessage(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool) (*ecdsa.Signature, error) { + + digest := chainhash.DoubleHashB(msg) + sig := ecdsa.Sign(s.key, digest) + + return sig, nil +} + +func (s *mockSigner) SignMessageSchnorr(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool, taprootTweak, tag []byte) (*schnorr.Signature, + error) { + + var digest []byte + switch { + case len(tag) > 0: + taggedHash := chainhash.TaggedHash(tag, msg) + digest = taggedHash[:] + case doubleHash: + digest = chainhash.DoubleHashB(msg) + default: + digest = chainhash.HashB(msg) + } + + privKey := s.key + if len(taprootTweak) > 0 { + privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak) + } + + return schnorr.Sign(privKey, digest) } // TestFindRoutesWithFeeLimit asserts that routes found by the FindRoutes method @@ -613,16 +650,16 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { ) require.NoError(t, err, "unable to fetch chan id") - edgeUpdToFail, ok := edgeUpdateToFail.(*models.ChannelEdgePolicy1) - require.True(t, ok) - errChanUpdate, err := netann.UnsignedChannelUpdateFromEdge( - chainhash.Hash{}, edgeUpdToFail, + chainhash.Hash{}, edgeUpdateToFail, ) require.NoError(t, err) signErrChanUpdate(t, ctx.privKeys["songoku"], errChanUpdate) + chanUpd, ok := errChanUpdate.(*lnwire.ChannelUpdate1) + require.True(t, ok) + // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to Son goku. This will be a fee related error, so // it should only cause the edge to be pruned after the second attempt. @@ -636,15 +673,17 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { roasbeefSongokuChanID, ) if firstHop == roasbeefSongoku { - return [32]byte{}, htlcswitch.NewForwardingError( - // Within our error, we'll add a - // channel update which is meant to - // reflect the new fee schedule for the - // node/channel. - &lnwire.FailFeeInsufficient{ - Update: *errChanUpdate, - }, 1, - ) + if firstHop == roasbeefSongoku { + return [32]byte{}, htlcswitch.NewForwardingError( + // Within our error, we'll add a + // channel update which is meant to + // reflect the new fee schedule for the + // node/channel. + &lnwire.FailFeeInsufficient{ + Update: *chanUpd, + }, 1, + ) + } } return preImage, nil @@ -961,6 +1000,9 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { ) require.NoError(t, err) + chanUpd, ok := errChanUpdate.(*lnwire.ChannelUpdate1) + require.True(t, ok) + // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to son goku. Since this is a time lock related // error, we should fail the payment flow all together, as Goku is the @@ -970,7 +1012,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ - Update: *errChanUpdate, + Update: *chanUpd, }, 1, ) } @@ -1018,7 +1060,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ - Update: *errChanUpdate, + Update: *chanUpd, }, 1, ) } @@ -2940,13 +2982,15 @@ func createDummyLightningPayment(t *testing.T, } type mockGraphBuilder struct { + t *testing.T rejectUpdate bool - updateEdge func(update *models.ChannelEdgePolicy1) error + updateEdge func(update models.ChannelEdgePolicy) error } -func newMockGraphBuilder(graph graph.DB) *mockGraphBuilder { +func newMockGraphBuilder(t *testing.T, graph graph.DB) *mockGraphBuilder { return &mockGraphBuilder{ - updateEdge: func(update *models.ChannelEdgePolicy1) error { + t: t, + updateEdge: func(update models.ChannelEdgePolicy) error { return graph.UpdateEdgePolicy(update) }, } @@ -2956,26 +3000,15 @@ func (m *mockGraphBuilder) setNextReject(reject bool) { m.rejectUpdate = reject } -func (m *mockGraphBuilder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { +func (m *mockGraphBuilder) ApplyChannelUpdate(msg lnwire.ChannelUpdate) bool { if m.rejectUpdate { return false } - err := m.updateEdge(&models.ChannelEdgePolicy1{ - SigBytes: msg.Signature.ToSignatureBytes(), - ChannelID: msg.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(msg.Timestamp), 0), - MessageFlags: msg.MessageFlags, - ChannelFlags: msg.ChannelFlags, - TimeLockDelta: msg.TimeLockDelta, - MinHTLC: msg.HtlcMinimumMsat, - MaxHTLC: msg.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), - ExtraOpaqueData: msg.ExtraOpaqueData, - }) + policy, err := models.EdgePolicyFromUpdate(msg) + require.NoError(m.t, err) - return err == nil + return m.updateEdge(policy) == nil } type mockChain struct { diff --git a/server.go b/server.go index 527ccfcb4..14939f53b 100644 --- a/server.go +++ b/server.go @@ -4828,7 +4828,7 @@ func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( // applyChannelUpdate applies the channel update to the different sub-systems of // the server. The useAlias boolean denotes whether or not to send an alias in // place of the real SCID. -func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate1, +func (s *server) applyChannelUpdate(update lnwire.ChannelUpdate, op *wire.OutPoint, useAlias bool) error { var (