From 07466c4f8c1290d6599a2e6ca5d1808b8dae7106 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 4 Apr 2024 11:10:03 +0800 Subject: [PATCH] multi: query circuit map inside contractcourt This commit adds a new config method `QueryIncomingCircuit` that can be used to query the payment's incoming circuit for giving its outgoing circuit key. --- contractcourt/chain_arbitrator.go | 100 ++++++++++++------ contractcourt/chain_arbitrator_test.go | 6 ++ contractcourt/channel_arbitrator.go | 12 +-- contractcourt/channel_arbitrator_test.go | 8 +- contractcourt/commit_sweep_resolver_test.go | 6 ++ .../htlc_incoming_contest_resolver_test.go | 5 + .../htlc_outgoing_contest_resolver_test.go | 6 ++ contractcourt/htlc_success_resolver_test.go | 6 ++ contractcourt/htlc_timeout_resolver_test.go | 4 + server.go | 21 +++- 10 files changed, 130 insertions(+), 44 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index e9f68c66a..e08618307 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -3,7 +3,6 @@ package contractcourt import ( "errors" "fmt" - "math" "sync" "sync/atomic" "time" @@ -14,6 +13,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -206,6 +206,17 @@ type ChainArbitratorConfig struct { // Budget is the configured budget for the arbitrator. Budget BudgetConfig + + // QueryIncomingCircuit is used to find the outgoing HTLC's + // corresponding incoming HTLC circuit. It queries the circuit map for + // a given outgoing circuit key and returns the incoming circuit key. + // + // TODO(yy): this is a hacky way to get around the cycling import issue + // as we cannot import `htlcswitch` here. A proper way is to define an + // interface here that asks for method `LookupOpenCircuit`, + // meanwhile, turn `PaymentCircuit` into an interface or bring it to a + // lower package. + QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -389,9 +400,11 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, return chanStateDB.FetchHistoricalChannel(&chanPoint) }, FindOutgoingHTLCDeadline: func( - rHash chainhash.Hash) fn.Option[int32] { + htlc channeldb.HTLC) fn.Option[int32] { - return c.FindOutgoingHTLCDeadline(chanPoint, rHash) + return c.FindOutgoingHTLCDeadline( + channel.ShortChanID(), htlc, + ) }, } @@ -612,9 +625,11 @@ func (c *ChainArbitrator) Start() error { return chanStateDB.FetchHistoricalChannel(&chanPoint) }, FindOutgoingHTLCDeadline: func( - rHash chainhash.Hash) fn.Option[int32] { + htlc channeldb.HTLC) fn.Option[int32] { - return c.FindOutgoingHTLCDeadline(chanPoint, rHash) + return c.FindOutgoingHTLCDeadline( + closeChanInfo.ShortChanID, htlc, + ) }, } chanLog, err := newBoltArbitratorLog( @@ -1224,22 +1239,48 @@ func (c *ChainArbitrator) SubscribeChannelEvents( // by the timeout height of its corresponding incoming HTLC - this is the // expiry height the that remote peer can spend his/her outgoing HTLC via the // timeout path. -func (c *ChainArbitrator) FindOutgoingHTLCDeadline(chanPoint wire.OutPoint, - rHash chainhash.Hash) fn.Option[int32] { +func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, + outgoingHTLC channeldb.HTLC) fn.Option[int32] { - // minRefundTimeout tracks the minimal refund timeout found using the - // rHash. It's possible that we find multiple HTLCs living in different - // channels sharing the same rHash if an MPP is routed by us. In this - // case, we'll use the smallest refund timeout as the deadline. - // - // TODO(yy): can instead query the circuit map to find the exact HTLC. - minRefundTimeout := uint32(math.MaxInt32) + // Find the outgoing HTLC's corresponding incoming HTLC in the circuit + // map. + rHash := outgoingHTLC.RHash + circuit := models.CircuitKey{ + ChanID: scid, + HtlcID: outgoingHTLC.HtlcIndex, + } + incomingCircuit := c.cfg.QueryIncomingCircuit(circuit) - // Iterate over all active channels to find the HTLC with the matching - // rHash. + // If there's no incoming circuit found, we will use the default + // deadline. + if incomingCircuit == nil { + log.Warnf("ChannelArbitrator(%v): incoming circuit key not "+ + "found for rHash=%x, using default deadline instead", + scid, rHash) + + return fn.None[int32]() + } + + // If this is a locally initiated HTLC, it means we are the first hop. + // In this case, we can relax the deadline. + if incomingCircuit.ChanID.IsDefault() { + log.Infof("ChannelArbitrator(%v): using default deadline for "+ + "locally initiated HTLC for rHash=%x", scid, rHash) + + return fn.None[int32]() + } + + log.Debugf("Found incoming circuit %v for rHash=%x using outgoing "+ + "circuit %v", incomingCircuit, rHash, circuit) + + c.Lock() + defer c.Unlock() + + // Iterate over all active channels to find the incoming HTLC specified + // by its circuit key. for cp, channelArb := range c.activeChannels { - // Skip the targeted channel as the incoming HTLC is not here. - if cp == chanPoint { + // Skip if the SCID doesn't match. + if channelArb.cfg.ShortChanID != incomingCircuit.ChanID { continue } @@ -1247,32 +1288,25 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(chanPoint wire.OutPoint, // HTLC. for _, htlcs := range channelArb.activeHTLCs { for _, htlc := range htlcs.incomingHTLCs { - if htlc.RHash != rHash { + // Skip if the index doesn't match. + if htlc.HtlcIndex != incomingCircuit.HtlcID { continue } log.Debugf("ChannelArbitrator(%v): found "+ "incoming HTLC in channel=%v using "+ - "rHash=%v, refundTimeout=%v", chanPoint, + "rHash=%x, refundTimeout=%v", scid, cp, rHash, htlc.RefundTimeout) - // Update the value if it's smaller. - if minRefundTimeout > htlc.RefundTimeout { - minRefundTimeout = htlc.RefundTimeout - } + return fn.Some(int32(htlc.RefundTimeout)) } } } - // Return the refund timeout value if found. - if minRefundTimeout != math.MaxInt32 { - return fn.Some(int32(minRefundTimeout)) - } - - // If there's no incoming HTLC found, it means we are the first hop. In - // this case, we can relax the deadline. - log.Infof("ChannelArbitrator(%v): incoming HTLC not found for "+ - "rHash=%v, using default deadline instead", chanPoint, rHash) + // If there's no incoming HTLC found, yet we have the incoming circuit, + // something is wrong - in this case, we return the none deadline. + log.Errorf("ChannelArbitrator(%v): incoming HTLC not found for "+ + "rHash=%x, using default deadline instead", scid, rHash) return fn.None[int32]() } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index d0a476a84..36f6dad18 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" @@ -172,6 +173,11 @@ func TestResolveContract(t *testing.T) { }, Clock: clock.NewDefaultClock(), Budget: *DefaultBudgetConfig(), + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + return nil + }, } chainArb := NewChainArbitrator( chainArbCfg, db, diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 8b3e193ff..7f582b333 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -170,7 +170,7 @@ type ChannelArbitratorConfig struct { // deadline is defined by the timeout height of its corresponding // incoming HTLC - this is the expiry height the that remote peer can // spend his/her outgoing HTLC via the timeout path. - FindOutgoingHTLCDeadline func(rHash chainhash.Hash) fn.Option[int32] + FindOutgoingHTLCDeadline func(htlc channeldb.HTLC) fn.Option[int32] ChainArbitratorConfig } @@ -764,7 +764,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // the resolver with the expiry block height of its // corresponding incoming HTLC. if !htlc.Incoming { - deadline := c.cfg.FindOutgoingHTLCDeadline(htlc.RHash) + deadline := c.cfg.FindOutgoingHTLCDeadline(*htlc) htlcResolver.SupplementDeadline(deadline) } } @@ -2421,9 +2421,7 @@ func (c *ChannelArbitrator) prepContractResolutions( // supplement the resolver with the expiry // block height of its corresponding incoming // HTLC. - deadline := c.cfg.FindOutgoingHTLCDeadline( - htlc.RHash, - ) + deadline := c.cfg.FindOutgoingHTLCDeadline(htlc) resolver.SupplementDeadline(deadline) htlcResolvers = append(htlcResolvers, resolver) @@ -2523,9 +2521,7 @@ func (c *ChannelArbitrator) prepContractResolutions( // supplement the resolver with the expiry // block height of its corresponding incoming // HTLC. - deadline := c.cfg.FindOutgoingHTLCDeadline( - htlc.RHash, - ) + deadline := c.cfg.FindOutgoingHTLCDeadline(htlc) resolver.SupplementDeadline(deadline) htlcResolvers = append(htlcResolvers, resolver) diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index a68e9d6b3..77c9597c0 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -395,6 +396,11 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, Budget: *DefaultBudgetConfig(), PreimageDB: newMockWitnessBeacon(), Registry: &mockRegistry{}, + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + return nil + }, } // We'll use the resolvedChan to synchronize on call to @@ -430,7 +436,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, return &channeldb.OpenChannel{}, nil }, FindOutgoingHTLCDeadline: func( - rHash chainhash.Hash) fn.Option[int32] { + htlc channeldb.HTLC) fn.Option[int32] { return fn.None[int32]() }, diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 7d42d7be0..b3221f5c5 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" @@ -43,6 +44,11 @@ func newCommitSweepResolverTestContext(t *testing.T, Notifier: notifier, Sweeper: sweeper, Budget: *DefaultBudgetConfig(), + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + return nil + }, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 7bfe285bd..34a672706 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -352,6 +352,11 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver }, HtlcNotifier: htlcNotifier, Budget: *DefaultBudgetConfig(), + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + return nil + }, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index c788241d8..e4a3aaee0 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -153,6 +154,11 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { }, OnionProcessor: onionProcessor, Budget: *DefaultBudgetConfig(), + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + return nil + }, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index c8ba57d93..b0ee21f1e 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -12,6 +12,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -92,6 +93,11 @@ func newHtlcResolverTestContext(t *testing.T, }, HtlcNotifier: htlcNotifier, Budget: *DefaultBudgetConfig(), + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + return nil + }, }, PutResolverReport: func(_ kvdb.RwTx, report *channeldb.ResolverReport) error { diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 2d639a82f..12cf63886 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -304,6 +305,9 @@ func TestHtlcTimeoutResolver(t *testing.T) { return nil }, Budget: *DefaultBudgetConfig(), + QueryIncomingCircuit: func(circuit models.CircuitKey) *models.CircuitKey { + return nil + }, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { diff --git a/server.go b/server.go index d8f4850f3..e2c0b831a 100644 --- a/server.go +++ b/server.go @@ -1134,6 +1134,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, ) + //nolint:lll s.chainArb = contractcourt.NewChainArbitrator(contractcourt.ChainArbitratorConfig{ ChainHash: *s.cfg.ActiveNetParams.GenesisHash, IncomingBroadcastDelta: lncfg.DefaultIncomingBroadcastDelta, @@ -1224,10 +1225,26 @@ func newServer(cfg *Config, listenAddrs []net.Addr, PaymentsExpirationGracePeriod: cfg.PaymentsExpirationGracePeriod, IsForwardedHTLC: s.htlcSwitch.IsForwardedHTLC, Clock: clock.NewDefaultClock(), - SubscribeBreachComplete: s.breachArbitrator.SubscribeBreachComplete, //nolint:lll - PutFinalHtlcOutcome: s.chanStateDB.PutOnchainFinalHtlcOutcome, //nolint: lll + SubscribeBreachComplete: s.breachArbitrator.SubscribeBreachComplete, + PutFinalHtlcOutcome: s.chanStateDB.PutOnchainFinalHtlcOutcome, HtlcNotifier: s.htlcNotifier, Budget: *s.cfg.Sweeper.Budget, + + // TODO(yy): remove this hack once PaymentCircuit is interfaced. + QueryIncomingCircuit: func( + circuit models.CircuitKey) *models.CircuitKey { + + // Get the circuit map. + circuits := s.htlcSwitch.CircuitLookup() + + // Lookup the outgoing circuit. + pc := circuits.LookupOpenCircuit(circuit) + if pc == nil { + return nil + } + + return &pc.Incoming + }, }, dbs.ChanStateDB) // Select the configuration and funding parameters for Bitcoin.