multi: query circuit map inside contractcourt

This commit adds a new config method `QueryIncomingCircuit` that can be
used to query the payment's incoming circuit for giving its outgoing
circuit key.
This commit is contained in:
yyforyongyu 2024-04-04 11:10:03 +08:00
parent 4134b1c00a
commit 07466c4f8c
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
10 changed files with 130 additions and 44 deletions

View File

@ -3,7 +3,6 @@ package contractcourt
import (
"errors"
"fmt"
"math"
"sync"
"sync/atomic"
"time"
@ -14,6 +13,7 @@ import (
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
@ -206,6 +206,17 @@ type ChainArbitratorConfig struct {
// Budget is the configured budget for the arbitrator.
Budget BudgetConfig
// QueryIncomingCircuit is used to find the outgoing HTLC's
// corresponding incoming HTLC circuit. It queries the circuit map for
// a given outgoing circuit key and returns the incoming circuit key.
//
// TODO(yy): this is a hacky way to get around the cycling import issue
// as we cannot import `htlcswitch` here. A proper way is to define an
// interface here that asks for method `LookupOpenCircuit`,
// meanwhile, turn `PaymentCircuit` into an interface or bring it to a
// lower package.
QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey
}
// ChainArbitrator is a sub-system that oversees the on-chain resolution of all
@ -389,9 +400,11 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel,
return chanStateDB.FetchHistoricalChannel(&chanPoint)
},
FindOutgoingHTLCDeadline: func(
rHash chainhash.Hash) fn.Option[int32] {
htlc channeldb.HTLC) fn.Option[int32] {
return c.FindOutgoingHTLCDeadline(chanPoint, rHash)
return c.FindOutgoingHTLCDeadline(
channel.ShortChanID(), htlc,
)
},
}
@ -612,9 +625,11 @@ func (c *ChainArbitrator) Start() error {
return chanStateDB.FetchHistoricalChannel(&chanPoint)
},
FindOutgoingHTLCDeadline: func(
rHash chainhash.Hash) fn.Option[int32] {
htlc channeldb.HTLC) fn.Option[int32] {
return c.FindOutgoingHTLCDeadline(chanPoint, rHash)
return c.FindOutgoingHTLCDeadline(
closeChanInfo.ShortChanID, htlc,
)
},
}
chanLog, err := newBoltArbitratorLog(
@ -1224,22 +1239,48 @@ func (c *ChainArbitrator) SubscribeChannelEvents(
// by the timeout height of its corresponding incoming HTLC - this is the
// expiry height the that remote peer can spend his/her outgoing HTLC via the
// timeout path.
func (c *ChainArbitrator) FindOutgoingHTLCDeadline(chanPoint wire.OutPoint,
rHash chainhash.Hash) fn.Option[int32] {
func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID,
outgoingHTLC channeldb.HTLC) fn.Option[int32] {
// minRefundTimeout tracks the minimal refund timeout found using the
// rHash. It's possible that we find multiple HTLCs living in different
// channels sharing the same rHash if an MPP is routed by us. In this
// case, we'll use the smallest refund timeout as the deadline.
//
// TODO(yy): can instead query the circuit map to find the exact HTLC.
minRefundTimeout := uint32(math.MaxInt32)
// Find the outgoing HTLC's corresponding incoming HTLC in the circuit
// map.
rHash := outgoingHTLC.RHash
circuit := models.CircuitKey{
ChanID: scid,
HtlcID: outgoingHTLC.HtlcIndex,
}
incomingCircuit := c.cfg.QueryIncomingCircuit(circuit)
// Iterate over all active channels to find the HTLC with the matching
// rHash.
// If there's no incoming circuit found, we will use the default
// deadline.
if incomingCircuit == nil {
log.Warnf("ChannelArbitrator(%v): incoming circuit key not "+
"found for rHash=%x, using default deadline instead",
scid, rHash)
return fn.None[int32]()
}
// If this is a locally initiated HTLC, it means we are the first hop.
// In this case, we can relax the deadline.
if incomingCircuit.ChanID.IsDefault() {
log.Infof("ChannelArbitrator(%v): using default deadline for "+
"locally initiated HTLC for rHash=%x", scid, rHash)
return fn.None[int32]()
}
log.Debugf("Found incoming circuit %v for rHash=%x using outgoing "+
"circuit %v", incomingCircuit, rHash, circuit)
c.Lock()
defer c.Unlock()
// Iterate over all active channels to find the incoming HTLC specified
// by its circuit key.
for cp, channelArb := range c.activeChannels {
// Skip the targeted channel as the incoming HTLC is not here.
if cp == chanPoint {
// Skip if the SCID doesn't match.
if channelArb.cfg.ShortChanID != incomingCircuit.ChanID {
continue
}
@ -1247,32 +1288,25 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(chanPoint wire.OutPoint,
// HTLC.
for _, htlcs := range channelArb.activeHTLCs {
for _, htlc := range htlcs.incomingHTLCs {
if htlc.RHash != rHash {
// Skip if the index doesn't match.
if htlc.HtlcIndex != incomingCircuit.HtlcID {
continue
}
log.Debugf("ChannelArbitrator(%v): found "+
"incoming HTLC in channel=%v using "+
"rHash=%v, refundTimeout=%v", chanPoint,
"rHash=%x, refundTimeout=%v", scid,
cp, rHash, htlc.RefundTimeout)
// Update the value if it's smaller.
if minRefundTimeout > htlc.RefundTimeout {
minRefundTimeout = htlc.RefundTimeout
}
return fn.Some(int32(htlc.RefundTimeout))
}
}
}
// Return the refund timeout value if found.
if minRefundTimeout != math.MaxInt32 {
return fn.Some(int32(minRefundTimeout))
}
// If there's no incoming HTLC found, it means we are the first hop. In
// this case, we can relax the deadline.
log.Infof("ChannelArbitrator(%v): incoming HTLC not found for "+
"rHash=%v, using default deadline instead", chanPoint, rHash)
// If there's no incoming HTLC found, yet we have the incoming circuit,
// something is wrong - in this case, we return the none deadline.
log.Errorf("ChannelArbitrator(%v): incoming HTLC not found for "+
"rHash=%x, using default deadline instead", scid, rHash)
return fn.None[int32]()
}

View File

@ -8,6 +8,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet"
@ -172,6 +173,11 @@ func TestResolveContract(t *testing.T) {
},
Clock: clock.NewDefaultClock(),
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
return nil
},
}
chainArb := NewChainArbitrator(
chainArbCfg, db,

View File

@ -170,7 +170,7 @@ type ChannelArbitratorConfig struct {
// deadline is defined by the timeout height of its corresponding
// incoming HTLC - this is the expiry height the that remote peer can
// spend his/her outgoing HTLC via the timeout path.
FindOutgoingHTLCDeadline func(rHash chainhash.Hash) fn.Option[int32]
FindOutgoingHTLCDeadline func(htlc channeldb.HTLC) fn.Option[int32]
ChainArbitratorConfig
}
@ -764,7 +764,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet,
// the resolver with the expiry block height of its
// corresponding incoming HTLC.
if !htlc.Incoming {
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc.RHash)
deadline := c.cfg.FindOutgoingHTLCDeadline(*htlc)
htlcResolver.SupplementDeadline(deadline)
}
}
@ -2421,9 +2421,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
// supplement the resolver with the expiry
// block height of its corresponding incoming
// HTLC.
deadline := c.cfg.FindOutgoingHTLCDeadline(
htlc.RHash,
)
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc)
resolver.SupplementDeadline(deadline)
htlcResolvers = append(htlcResolvers, resolver)
@ -2523,9 +2521,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
// supplement the resolver with the expiry
// block height of its corresponding incoming
// HTLC.
deadline := c.cfg.FindOutgoingHTLCDeadline(
htlc.RHash,
)
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc)
resolver.SupplementDeadline(deadline)
htlcResolvers = append(htlcResolvers, resolver)

View File

@ -15,6 +15,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
@ -395,6 +396,11 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
Budget: *DefaultBudgetConfig(),
PreimageDB: newMockWitnessBeacon(),
Registry: &mockRegistry{},
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
return nil
},
}
// We'll use the resolvedChan to synchronize on call to
@ -430,7 +436,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
return &channeldb.OpenChannel{}, nil
},
FindOutgoingHTLCDeadline: func(
rHash chainhash.Hash) fn.Option[int32] {
htlc channeldb.HTLC) fn.Option[int32] {
return fn.None[int32]()
},

View File

@ -9,6 +9,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntest/mock"
@ -43,6 +44,11 @@ func newCommitSweepResolverTestContext(t *testing.T,
Notifier: notifier,
Sweeper: sweeper,
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {

View File

@ -352,6 +352,11 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver
},
HtlcNotifier: htlcNotifier,
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {

View File

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock"
@ -153,6 +154,11 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext {
},
OnionProcessor: onionProcessor,
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {

View File

@ -12,6 +12,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
@ -92,6 +93,11 @@ func newHtlcResolverTestContext(t *testing.T,
},
HtlcNotifier: htlcNotifier,
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
report *channeldb.ResolverReport) error {

View File

@ -14,6 +14,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
@ -304,6 +305,9 @@ func TestHtlcTimeoutResolver(t *testing.T) {
return nil
},
Budget: *DefaultBudgetConfig(),
QueryIncomingCircuit: func(circuit models.CircuitKey) *models.CircuitKey {
return nil
},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {

View File

@ -1134,6 +1134,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
},
)
//nolint:lll
s.chainArb = contractcourt.NewChainArbitrator(contractcourt.ChainArbitratorConfig{
ChainHash: *s.cfg.ActiveNetParams.GenesisHash,
IncomingBroadcastDelta: lncfg.DefaultIncomingBroadcastDelta,
@ -1224,10 +1225,26 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
PaymentsExpirationGracePeriod: cfg.PaymentsExpirationGracePeriod,
IsForwardedHTLC: s.htlcSwitch.IsForwardedHTLC,
Clock: clock.NewDefaultClock(),
SubscribeBreachComplete: s.breachArbitrator.SubscribeBreachComplete, //nolint:lll
PutFinalHtlcOutcome: s.chanStateDB.PutOnchainFinalHtlcOutcome, //nolint: lll
SubscribeBreachComplete: s.breachArbitrator.SubscribeBreachComplete,
PutFinalHtlcOutcome: s.chanStateDB.PutOnchainFinalHtlcOutcome,
HtlcNotifier: s.htlcNotifier,
Budget: *s.cfg.Sweeper.Budget,
// TODO(yy): remove this hack once PaymentCircuit is interfaced.
QueryIncomingCircuit: func(
circuit models.CircuitKey) *models.CircuitKey {
// Get the circuit map.
circuits := s.htlcSwitch.CircuitLookup()
// Lookup the outgoing circuit.
pc := circuits.LookupOpenCircuit(circuit)
if pc == nil {
return nil
}
return &pc.Incoming
},
}, dbs.ChanStateDB)
// Select the configuration and funding parameters for Bitcoin.