mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-02 11:09:38 +02:00
multi: query circuit map inside contractcourt
This commit adds a new config method `QueryIncomingCircuit` that can be used to query the payment's incoming circuit for giving its outgoing circuit key.
This commit is contained in:
parent
4134b1c00a
commit
07466c4f8c
@ -3,7 +3,6 @@ package contractcourt
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
@ -206,6 +206,17 @@ type ChainArbitratorConfig struct {
|
||||
|
||||
// Budget is the configured budget for the arbitrator.
|
||||
Budget BudgetConfig
|
||||
|
||||
// QueryIncomingCircuit is used to find the outgoing HTLC's
|
||||
// corresponding incoming HTLC circuit. It queries the circuit map for
|
||||
// a given outgoing circuit key and returns the incoming circuit key.
|
||||
//
|
||||
// TODO(yy): this is a hacky way to get around the cycling import issue
|
||||
// as we cannot import `htlcswitch` here. A proper way is to define an
|
||||
// interface here that asks for method `LookupOpenCircuit`,
|
||||
// meanwhile, turn `PaymentCircuit` into an interface or bring it to a
|
||||
// lower package.
|
||||
QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey
|
||||
}
|
||||
|
||||
// ChainArbitrator is a sub-system that oversees the on-chain resolution of all
|
||||
@ -389,9 +400,11 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel,
|
||||
return chanStateDB.FetchHistoricalChannel(&chanPoint)
|
||||
},
|
||||
FindOutgoingHTLCDeadline: func(
|
||||
rHash chainhash.Hash) fn.Option[int32] {
|
||||
htlc channeldb.HTLC) fn.Option[int32] {
|
||||
|
||||
return c.FindOutgoingHTLCDeadline(chanPoint, rHash)
|
||||
return c.FindOutgoingHTLCDeadline(
|
||||
channel.ShortChanID(), htlc,
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
@ -612,9 +625,11 @@ func (c *ChainArbitrator) Start() error {
|
||||
return chanStateDB.FetchHistoricalChannel(&chanPoint)
|
||||
},
|
||||
FindOutgoingHTLCDeadline: func(
|
||||
rHash chainhash.Hash) fn.Option[int32] {
|
||||
htlc channeldb.HTLC) fn.Option[int32] {
|
||||
|
||||
return c.FindOutgoingHTLCDeadline(chanPoint, rHash)
|
||||
return c.FindOutgoingHTLCDeadline(
|
||||
closeChanInfo.ShortChanID, htlc,
|
||||
)
|
||||
},
|
||||
}
|
||||
chanLog, err := newBoltArbitratorLog(
|
||||
@ -1224,22 +1239,48 @@ func (c *ChainArbitrator) SubscribeChannelEvents(
|
||||
// by the timeout height of its corresponding incoming HTLC - this is the
|
||||
// expiry height the that remote peer can spend his/her outgoing HTLC via the
|
||||
// timeout path.
|
||||
func (c *ChainArbitrator) FindOutgoingHTLCDeadline(chanPoint wire.OutPoint,
|
||||
rHash chainhash.Hash) fn.Option[int32] {
|
||||
func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID,
|
||||
outgoingHTLC channeldb.HTLC) fn.Option[int32] {
|
||||
|
||||
// minRefundTimeout tracks the minimal refund timeout found using the
|
||||
// rHash. It's possible that we find multiple HTLCs living in different
|
||||
// channels sharing the same rHash if an MPP is routed by us. In this
|
||||
// case, we'll use the smallest refund timeout as the deadline.
|
||||
//
|
||||
// TODO(yy): can instead query the circuit map to find the exact HTLC.
|
||||
minRefundTimeout := uint32(math.MaxInt32)
|
||||
// Find the outgoing HTLC's corresponding incoming HTLC in the circuit
|
||||
// map.
|
||||
rHash := outgoingHTLC.RHash
|
||||
circuit := models.CircuitKey{
|
||||
ChanID: scid,
|
||||
HtlcID: outgoingHTLC.HtlcIndex,
|
||||
}
|
||||
incomingCircuit := c.cfg.QueryIncomingCircuit(circuit)
|
||||
|
||||
// Iterate over all active channels to find the HTLC with the matching
|
||||
// rHash.
|
||||
// If there's no incoming circuit found, we will use the default
|
||||
// deadline.
|
||||
if incomingCircuit == nil {
|
||||
log.Warnf("ChannelArbitrator(%v): incoming circuit key not "+
|
||||
"found for rHash=%x, using default deadline instead",
|
||||
scid, rHash)
|
||||
|
||||
return fn.None[int32]()
|
||||
}
|
||||
|
||||
// If this is a locally initiated HTLC, it means we are the first hop.
|
||||
// In this case, we can relax the deadline.
|
||||
if incomingCircuit.ChanID.IsDefault() {
|
||||
log.Infof("ChannelArbitrator(%v): using default deadline for "+
|
||||
"locally initiated HTLC for rHash=%x", scid, rHash)
|
||||
|
||||
return fn.None[int32]()
|
||||
}
|
||||
|
||||
log.Debugf("Found incoming circuit %v for rHash=%x using outgoing "+
|
||||
"circuit %v", incomingCircuit, rHash, circuit)
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Iterate over all active channels to find the incoming HTLC specified
|
||||
// by its circuit key.
|
||||
for cp, channelArb := range c.activeChannels {
|
||||
// Skip the targeted channel as the incoming HTLC is not here.
|
||||
if cp == chanPoint {
|
||||
// Skip if the SCID doesn't match.
|
||||
if channelArb.cfg.ShortChanID != incomingCircuit.ChanID {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -1247,32 +1288,25 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(chanPoint wire.OutPoint,
|
||||
// HTLC.
|
||||
for _, htlcs := range channelArb.activeHTLCs {
|
||||
for _, htlc := range htlcs.incomingHTLCs {
|
||||
if htlc.RHash != rHash {
|
||||
// Skip if the index doesn't match.
|
||||
if htlc.HtlcIndex != incomingCircuit.HtlcID {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("ChannelArbitrator(%v): found "+
|
||||
"incoming HTLC in channel=%v using "+
|
||||
"rHash=%v, refundTimeout=%v", chanPoint,
|
||||
"rHash=%x, refundTimeout=%v", scid,
|
||||
cp, rHash, htlc.RefundTimeout)
|
||||
|
||||
// Update the value if it's smaller.
|
||||
if minRefundTimeout > htlc.RefundTimeout {
|
||||
minRefundTimeout = htlc.RefundTimeout
|
||||
}
|
||||
return fn.Some(int32(htlc.RefundTimeout))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return the refund timeout value if found.
|
||||
if minRefundTimeout != math.MaxInt32 {
|
||||
return fn.Some(int32(minRefundTimeout))
|
||||
}
|
||||
|
||||
// If there's no incoming HTLC found, it means we are the first hop. In
|
||||
// this case, we can relax the deadline.
|
||||
log.Infof("ChannelArbitrator(%v): incoming HTLC not found for "+
|
||||
"rHash=%v, using default deadline instead", chanPoint, rHash)
|
||||
// If there's no incoming HTLC found, yet we have the incoming circuit,
|
||||
// something is wrong - in this case, we return the none deadline.
|
||||
log.Errorf("ChannelArbitrator(%v): incoming HTLC not found for "+
|
||||
"rHash=%x, using default deadline instead", scid, rHash)
|
||||
|
||||
return fn.None[int32]()
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntest/mock"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
@ -172,6 +173,11 @@ func TestResolveContract(t *testing.T) {
|
||||
},
|
||||
Clock: clock.NewDefaultClock(),
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
chainArb := NewChainArbitrator(
|
||||
chainArbCfg, db,
|
||||
|
@ -170,7 +170,7 @@ type ChannelArbitratorConfig struct {
|
||||
// deadline is defined by the timeout height of its corresponding
|
||||
// incoming HTLC - this is the expiry height the that remote peer can
|
||||
// spend his/her outgoing HTLC via the timeout path.
|
||||
FindOutgoingHTLCDeadline func(rHash chainhash.Hash) fn.Option[int32]
|
||||
FindOutgoingHTLCDeadline func(htlc channeldb.HTLC) fn.Option[int32]
|
||||
|
||||
ChainArbitratorConfig
|
||||
}
|
||||
@ -764,7 +764,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet,
|
||||
// the resolver with the expiry block height of its
|
||||
// corresponding incoming HTLC.
|
||||
if !htlc.Incoming {
|
||||
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc.RHash)
|
||||
deadline := c.cfg.FindOutgoingHTLCDeadline(*htlc)
|
||||
htlcResolver.SupplementDeadline(deadline)
|
||||
}
|
||||
}
|
||||
@ -2421,9 +2421,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
|
||||
// supplement the resolver with the expiry
|
||||
// block height of its corresponding incoming
|
||||
// HTLC.
|
||||
deadline := c.cfg.FindOutgoingHTLCDeadline(
|
||||
htlc.RHash,
|
||||
)
|
||||
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc)
|
||||
resolver.SupplementDeadline(deadline)
|
||||
|
||||
htlcResolvers = append(htlcResolvers, resolver)
|
||||
@ -2523,9 +2521,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
|
||||
// supplement the resolver with the expiry
|
||||
// block height of its corresponding incoming
|
||||
// HTLC.
|
||||
deadline := c.cfg.FindOutgoingHTLCDeadline(
|
||||
htlc.RHash,
|
||||
)
|
||||
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc)
|
||||
resolver.SupplementDeadline(deadline)
|
||||
|
||||
htlcResolvers = append(htlcResolvers, resolver)
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
@ -395,6 +396,11 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
PreimageDB: newMockWitnessBeacon(),
|
||||
Registry: &mockRegistry{},
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// We'll use the resolvedChan to synchronize on call to
|
||||
@ -430,7 +436,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog,
|
||||
return &channeldb.OpenChannel{}, nil
|
||||
},
|
||||
FindOutgoingHTLCDeadline: func(
|
||||
rHash chainhash.Hash) fn.Option[int32] {
|
||||
htlc channeldb.HTLC) fn.Option[int32] {
|
||||
|
||||
return fn.None[int32]()
|
||||
},
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntest/mock"
|
||||
@ -43,6 +44,11 @@ func newCommitSweepResolverTestContext(t *testing.T,
|
||||
Notifier: notifier,
|
||||
Sweeper: sweeper,
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
PutResolverReport: func(_ kvdb.RwTx,
|
||||
_ *channeldb.ResolverReport) error {
|
||||
|
@ -352,6 +352,11 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver
|
||||
},
|
||||
HtlcNotifier: htlcNotifier,
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
PutResolverReport: func(_ kvdb.RwTx,
|
||||
_ *channeldb.ResolverReport) error {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnmock"
|
||||
@ -153,6 +154,11 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext {
|
||||
},
|
||||
OnionProcessor: onionProcessor,
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
PutResolverReport: func(_ kvdb.RwTx,
|
||||
_ *channeldb.ResolverReport) error {
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
@ -92,6 +93,11 @@ func newHtlcResolverTestContext(t *testing.T,
|
||||
},
|
||||
HtlcNotifier: htlcNotifier,
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
PutResolverReport: func(_ kvdb.RwTx,
|
||||
report *channeldb.ResolverReport) error {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
@ -304,6 +305,9 @@ func TestHtlcTimeoutResolver(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
Budget: *DefaultBudgetConfig(),
|
||||
QueryIncomingCircuit: func(circuit models.CircuitKey) *models.CircuitKey {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
PutResolverReport: func(_ kvdb.RwTx,
|
||||
_ *channeldb.ResolverReport) error {
|
||||
|
21
server.go
21
server.go
@ -1134,6 +1134,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
||||
},
|
||||
)
|
||||
|
||||
//nolint:lll
|
||||
s.chainArb = contractcourt.NewChainArbitrator(contractcourt.ChainArbitratorConfig{
|
||||
ChainHash: *s.cfg.ActiveNetParams.GenesisHash,
|
||||
IncomingBroadcastDelta: lncfg.DefaultIncomingBroadcastDelta,
|
||||
@ -1224,10 +1225,26 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
||||
PaymentsExpirationGracePeriod: cfg.PaymentsExpirationGracePeriod,
|
||||
IsForwardedHTLC: s.htlcSwitch.IsForwardedHTLC,
|
||||
Clock: clock.NewDefaultClock(),
|
||||
SubscribeBreachComplete: s.breachArbitrator.SubscribeBreachComplete, //nolint:lll
|
||||
PutFinalHtlcOutcome: s.chanStateDB.PutOnchainFinalHtlcOutcome, //nolint: lll
|
||||
SubscribeBreachComplete: s.breachArbitrator.SubscribeBreachComplete,
|
||||
PutFinalHtlcOutcome: s.chanStateDB.PutOnchainFinalHtlcOutcome,
|
||||
HtlcNotifier: s.htlcNotifier,
|
||||
Budget: *s.cfg.Sweeper.Budget,
|
||||
|
||||
// TODO(yy): remove this hack once PaymentCircuit is interfaced.
|
||||
QueryIncomingCircuit: func(
|
||||
circuit models.CircuitKey) *models.CircuitKey {
|
||||
|
||||
// Get the circuit map.
|
||||
circuits := s.htlcSwitch.CircuitLookup()
|
||||
|
||||
// Lookup the outgoing circuit.
|
||||
pc := circuits.LookupOpenCircuit(circuit)
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pc.Incoming
|
||||
},
|
||||
}, dbs.ChanStateDB)
|
||||
|
||||
// Select the configuration and funding parameters for Bitcoin.
|
||||
|
Loading…
x
Reference in New Issue
Block a user