mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-04-08 20:28:04 +02:00
Merge pull request #1166 from cfromknecht/switch-isolate-pending-channels
Isolate Pending Channels in Switch
This commit is contained in:
commit
444a6c08cc
@ -1437,7 +1437,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
|
||||
RemoteChanCfg: bobCfg,
|
||||
IdentityPub: aliceKeyPub,
|
||||
FundingOutpoint: *prevOut,
|
||||
ShortChanID: shortChanID,
|
||||
ShortChannelID: shortChanID,
|
||||
ChanType: channeldb.SingleFunder,
|
||||
IsInitiator: true,
|
||||
Capacity: channelCapacity,
|
||||
@ -1455,7 +1455,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
|
||||
RemoteChanCfg: aliceCfg,
|
||||
IdentityPub: bobKeyPub,
|
||||
FundingOutpoint: *prevOut,
|
||||
ShortChanID: shortChanID,
|
||||
ShortChannelID: shortChanID,
|
||||
ChanType: channeldb.SingleFunder,
|
||||
IsInitiator: false,
|
||||
Capacity: channelCapacity,
|
||||
|
@ -340,10 +340,10 @@ type OpenChannel struct {
|
||||
// target blockchain as specified by the chain hash parameter.
|
||||
FundingOutpoint wire.OutPoint
|
||||
|
||||
// ShortChanID encodes the exact location in the chain in which the
|
||||
// ShortChannelID encodes the exact location in the chain in which the
|
||||
// channel was initially confirmed. This includes: the block height,
|
||||
// transaction index, and the output within the target transaction.
|
||||
ShortChanID lnwire.ShortChannelID
|
||||
ShortChannelID lnwire.ShortChannelID
|
||||
|
||||
// IsPending indicates whether a channel's funding transaction has been
|
||||
// confirmed.
|
||||
@ -460,6 +460,47 @@ func (c *OpenChannel) FullSync() error {
|
||||
return c.Db.Update(c.fullSync)
|
||||
}
|
||||
|
||||
// ShortChanID returns the current ShortChannelID of this channel.
|
||||
func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
return c.ShortChannelID
|
||||
}
|
||||
|
||||
// RefreshShortChanID updates the in-memory short channel ID using the latest
|
||||
// value observed on disk.
|
||||
func (c *OpenChannel) RefreshShortChanID() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
var sid lnwire.ShortChannelID
|
||||
err := c.Db.View(func(tx *bolt.Tx) error {
|
||||
chanBucket, err := readChanBucket(
|
||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sid = channel.ShortChannelID
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.ShortChannelID = sid
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChanBucket is a helper function that returns a writable bucket that a
|
||||
// channel's data resides in given: the public key for the node, the outpoint,
|
||||
// and the chainhash that the channel resides on.
|
||||
@ -582,7 +623,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
|
||||
}
|
||||
|
||||
channel.IsPending = false
|
||||
channel.ShortChanID = openLoc
|
||||
channel.ShortChannelID = openLoc
|
||||
|
||||
return putOpenChannel(chanBucket, channel)
|
||||
}); err != nil {
|
||||
@ -590,7 +631,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
|
||||
}
|
||||
|
||||
c.IsPending = false
|
||||
c.ShortChanID = openLoc
|
||||
c.ShortChannelID = openLoc
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -692,7 +733,7 @@ func fetchOpenChannel(chanBucket *bolt.Bucket,
|
||||
return nil, fmt.Errorf("unable to fetch chan revocations: %v", err)
|
||||
}
|
||||
|
||||
channel.Packager = NewChannelPackager(channel.ShortChanID)
|
||||
channel.Packager = NewChannelPackager(channel.ShortChannelID)
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
@ -1977,7 +2018,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
|
||||
var w bytes.Buffer
|
||||
if err := writeElements(&w,
|
||||
channel.ChanType, channel.ChainHash, channel.FundingOutpoint,
|
||||
channel.ShortChanID, channel.IsPending, channel.IsInitiator,
|
||||
channel.ShortChannelID, channel.IsPending, channel.IsInitiator,
|
||||
channel.ChanStatus, channel.FundingBroadcastHeight,
|
||||
channel.NumConfsRequired, channel.ChannelFlags,
|
||||
channel.IdentityPub, channel.Capacity, channel.TotalMSatSent,
|
||||
@ -2087,7 +2128,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
|
||||
|
||||
if err := readElements(r,
|
||||
&channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint,
|
||||
&channel.ShortChanID, &channel.IsPending, &channel.IsInitiator,
|
||||
&channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator,
|
||||
&channel.ChanStatus, &channel.FundingBroadcastHeight,
|
||||
&channel.NumConfsRequired, &channel.ChannelFlags,
|
||||
&channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent,
|
||||
@ -2110,7 +2151,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
|
||||
return err
|
||||
}
|
||||
|
||||
channel.Packager = NewChannelPackager(channel.ShortChanID)
|
||||
channel.Packager = NewChannelPackager(channel.ShortChannelID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -186,7 +186,7 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) {
|
||||
ChanType: SingleFunder,
|
||||
ChainHash: key,
|
||||
FundingOutpoint: *testOutpoint,
|
||||
ShortChanID: chanID,
|
||||
ShortChannelID: chanID,
|
||||
IsInitiator: true,
|
||||
IsPending: true,
|
||||
IdentityPub: pubKey,
|
||||
@ -514,7 +514,7 @@ func TestChannelStateTransition(t *testing.T) {
|
||||
}
|
||||
channel.RemoteNextRevocation = newPriv.PubKey()
|
||||
|
||||
fwdPkg := NewFwdPkg(channel.ShortChanID, oldRemoteCommit.CommitHeight,
|
||||
fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
|
||||
diskCommitDiff.LogUpdates, nil)
|
||||
|
||||
err = channel.AdvanceCommitChainTail(fwdPkg)
|
||||
@ -563,7 +563,7 @@ func TestChannelStateTransition(t *testing.T) {
|
||||
t.Fatalf("unable to add to commit chain: %v", err)
|
||||
}
|
||||
|
||||
fwdPkg = NewFwdPkg(channel.ShortChanID, oldRemoteCommit.CommitHeight, nil, nil)
|
||||
fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)
|
||||
|
||||
err = channel.AdvanceCommitChainTail(fwdPkg)
|
||||
if err != nil {
|
||||
@ -688,9 +688,9 @@ func TestFetchPendingChannels(t *testing.T) {
|
||||
t.Fatalf("channel marked open should no longer be pending")
|
||||
}
|
||||
|
||||
if pendingChannels[0].ShortChanID != chanOpenLoc {
|
||||
if pendingChannels[0].ShortChanID() != chanOpenLoc {
|
||||
t.Fatalf("channel opening height not updated: expected %v, "+
|
||||
"got %v", spew.Sdump(pendingChannels[0].ShortChanID),
|
||||
"got %v", spew.Sdump(pendingChannels[0].ShortChanID()),
|
||||
chanOpenLoc)
|
||||
}
|
||||
|
||||
@ -700,9 +700,9 @@ func TestFetchPendingChannels(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch channels: %v", err)
|
||||
}
|
||||
if openChans[0].ShortChanID != chanOpenLoc {
|
||||
if openChans[0].ShortChanID() != chanOpenLoc {
|
||||
t.Fatalf("channel opening heights don't match: expected %v, "+
|
||||
"got %v", spew.Sdump(openChans[0].ShortChanID),
|
||||
"got %v", spew.Sdump(openChans[0].ShortChanID()),
|
||||
chanOpenLoc)
|
||||
}
|
||||
if openChans[0].FundingBroadcastHeight != broadcastHeight {
|
||||
@ -830,3 +830,85 @@ func TestFetchClosedChannels(t *testing.T) {
|
||||
"got %v", 0, len(closed))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory
|
||||
// short channel ID of another OpenChannel to reflect a preceding call to
|
||||
// MarkOpen on a different OpenChannel.
|
||||
func TestRefreshShortChanID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cdb, cleanUp, err := makeTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test database: %v", err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// First create a test channel.
|
||||
state, err := createTestChannelState(cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create channel state: %v", err)
|
||||
}
|
||||
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 18555,
|
||||
}
|
||||
|
||||
// Mark the channel as pending within the channeldb.
|
||||
const broadcastHeight = 99
|
||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
||||
}
|
||||
|
||||
// Next, locate the pending channel with the database.
|
||||
pendingChannels, err := cdb.FetchPendingChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to load pending channels; %v", err)
|
||||
}
|
||||
|
||||
var pendingChannel *OpenChannel
|
||||
for _, channel := range pendingChannels {
|
||||
if channel.FundingOutpoint == state.FundingOutpoint {
|
||||
pendingChannel = channel
|
||||
break
|
||||
}
|
||||
}
|
||||
if pendingChannel == nil {
|
||||
t.Fatalf("unable to find pending channel with funding "+
|
||||
"outpoint=%v: %v", state.FundingOutpoint, err)
|
||||
}
|
||||
|
||||
// Next, simulate the confirmation of the channel by marking it as
|
||||
// pending within the database.
|
||||
chanOpenLoc := lnwire.ShortChannelID{
|
||||
BlockHeight: 105,
|
||||
TxIndex: 10,
|
||||
TxPosition: 15,
|
||||
}
|
||||
|
||||
err = state.MarkAsOpen(chanOpenLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to mark channel open: %v", err)
|
||||
}
|
||||
|
||||
// The short_chan_id of the receiver to MarkAsOpen should reflect the
|
||||
// open location, but the other pending channel should remain unchanged.
|
||||
if state.ShortChanID() == pendingChannel.ShortChanID() {
|
||||
t.Fatalf("pending channel short_chan_ID should not have been " +
|
||||
"updated before refreshing short_chan_id")
|
||||
}
|
||||
|
||||
// Now, refresh the short channel ID of the pending channel.
|
||||
err = pendingChannel.RefreshShortChanID()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to refresh short_chan_id: %v", err)
|
||||
}
|
||||
|
||||
// This should result in both OpenChannel's now having the same
|
||||
// ShortChanID.
|
||||
if state.ShortChanID() != pendingChannel.ShortChanID() {
|
||||
t.Fatalf("expected pending channel short_chan_id to be "+
|
||||
"refreshed: want %v, got %v", state.ShortChanID(),
|
||||
pendingChannel.ShortChanID())
|
||||
}
|
||||
}
|
||||
|
@ -201,7 +201,7 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel,
|
||||
// all interfaces and methods the arbitrator needs to do its job.
|
||||
arbCfg := ChannelArbitratorConfig{
|
||||
ChanPoint: chanPoint,
|
||||
ShortChanID: channel.ShortChanID,
|
||||
ShortChanID: channel.ShortChanID(),
|
||||
BlockEpochs: blockEpoch,
|
||||
ForceCloseChan: func() (*lnwallet.LocalForceCloseSummary, error) {
|
||||
// With the channels fetched, attempt to locate
|
||||
|
@ -180,7 +180,7 @@ func (c *chainWatcher) Start() error {
|
||||
// As a height hint, we'll try to use the opening height, but if the
|
||||
// channel isn't yet open, then we'll use the height it was broadcast
|
||||
// at.
|
||||
heightHint := chanState.ShortChanID.BlockHeight
|
||||
heightHint := c.cfg.chanState.ShortChanID().BlockHeight
|
||||
if heightHint == 0 {
|
||||
heightHint = chanState.FundingBroadcastHeight
|
||||
}
|
||||
@ -472,7 +472,7 @@ func (c *chainWatcher) dispatchCooperativeClose(commitSpend *chainntnfs.SpendDet
|
||||
CloseHeight: uint32(commitSpend.SpendingHeight),
|
||||
SettledBalance: localAmt,
|
||||
CloseType: channeldb.CooperativeClose,
|
||||
ShortChanID: c.cfg.chanState.ShortChanID,
|
||||
ShortChanID: c.cfg.chanState.ShortChanID(),
|
||||
IsPending: true,
|
||||
}
|
||||
err := c.cfg.chanState.CloseChannel(closeSummary)
|
||||
@ -564,7 +564,7 @@ func (c *chainWatcher) dispatchLocalForceClose(
|
||||
Capacity: chanSnapshot.Capacity,
|
||||
CloseType: channeldb.LocalForceClose,
|
||||
IsPending: true,
|
||||
ShortChanID: c.cfg.chanState.ShortChanID,
|
||||
ShortChanID: c.cfg.chanState.ShortChanID(),
|
||||
CloseHeight: uint32(commitSpend.SpendingHeight),
|
||||
}
|
||||
|
||||
@ -739,7 +739,7 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail
|
||||
SettledBalance: settledBalance,
|
||||
CloseType: channeldb.BreachClose,
|
||||
IsPending: true,
|
||||
ShortChanID: c.cfg.chanState.ShortChanID,
|
||||
ShortChanID: c.cfg.chanState.ShortChanID(),
|
||||
}
|
||||
|
||||
if err := c.cfg.chanState.CloseChannel(&closeSummary); err != nil {
|
||||
|
@ -304,7 +304,7 @@ type fundingConfig struct {
|
||||
// ReportShortChanID allows the funding manager to report the newly
|
||||
// discovered short channel ID of a formerly pending channel to outside
|
||||
// sub-systems.
|
||||
ReportShortChanID func(wire.OutPoint, lnwire.ShortChannelID) error
|
||||
ReportShortChanID func(wire.OutPoint) error
|
||||
|
||||
// ZombieSweeperInterval is the periodic time interval in which the
|
||||
// zombie sweeper is run.
|
||||
@ -1814,8 +1814,9 @@ func (f *fundingManager) waitForFundingConfirmation(completeChan *channeldb.Open
|
||||
}
|
||||
|
||||
// As there might already be an active link in the switch with an
|
||||
// outdated short chan ID, we'll update it now.
|
||||
err = f.cfg.ReportShortChanID(fundingPoint, shortChanID)
|
||||
// outdated short chan ID, we'll instruct the switch to load the updated
|
||||
// short chan id from disk.
|
||||
err = f.cfg.ReportShortChanID(fundingPoint)
|
||||
if err != nil {
|
||||
fndgLog.Errorf("unable to report short chan id: %v", err)
|
||||
}
|
||||
|
@ -307,7 +307,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey,
|
||||
WatchNewChannel: func(*channeldb.OpenChannel, *lnwire.NetAddress) error {
|
||||
return nil
|
||||
},
|
||||
ReportShortChanID: func(wire.OutPoint, lnwire.ShortChannelID) error {
|
||||
ReportShortChanID: func(wire.OutPoint) error {
|
||||
return nil
|
||||
},
|
||||
ZombieSweeperInterval: 1 * time.Hour,
|
||||
|
@ -364,7 +364,7 @@ func (cm *circuitMap) trimAllOpenCircuits() error {
|
||||
// First, skip any channels that have not been assigned their
|
||||
// final channel identifier, otherwise we would try to trim
|
||||
// htlcs belonging to the all-zero, sourceHop ID.
|
||||
chanID := activeChannel.ShortChanID
|
||||
chanID := activeChannel.ShortChanID()
|
||||
if chanID == sourceHop {
|
||||
continue
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ type ChannelLink interface {
|
||||
// be required in the event that a link is created before the short
|
||||
// chan ID for it is known, or a re-org occurs, and the funding
|
||||
// transaction changes location within the chain.
|
||||
UpdateShortChanID(lnwire.ShortChannelID)
|
||||
UpdateShortChanID() (lnwire.ShortChannelID, error)
|
||||
|
||||
// UpdateForwardingPolicy updates the forwarding policy for the target
|
||||
// ChannelLink. Once updated, the link will use the new forwarding
|
||||
|
@ -32,9 +32,14 @@ const (
|
||||
expiryGraceDelta = 2
|
||||
)
|
||||
|
||||
// ErrInternalLinkFailure is a generic error returned to the remote party so as
|
||||
// to obfuscate the true failure.
|
||||
var ErrInternalLinkFailure = errors.New("internal link failure")
|
||||
var (
|
||||
// ErrInternalLinkFailure is a generic error returned to the remote
|
||||
// party so as to obfuscate the true failure.
|
||||
ErrInternalLinkFailure = errors.New("internal link failure")
|
||||
|
||||
// ErrLinkShuttingDown signals that the link is shutting down.
|
||||
ErrLinkShuttingDown = errors.New("link shutting down")
|
||||
)
|
||||
|
||||
// ForwardingPolicy describes the set of constraints that a given ChannelLink
|
||||
// is to adhere to when forwarding HTLC's. For each incoming HTLC, this set of
|
||||
@ -444,9 +449,11 @@ func (l *channelLink) Stop() {
|
||||
// EligibleToForward returns a bool indicating if the channel is able to
|
||||
// actively accept requests to forward HTLC's. We're able to forward HTLC's if
|
||||
// we know the remote party's next revocation point. Otherwise, we can't
|
||||
// initiate new channel state.
|
||||
// initiate new channel state. We also require that the short channel ID not be
|
||||
// the all-zero source ID, meaning that the channel has had its ID finalized.
|
||||
func (l *channelLink) EligibleToForward() bool {
|
||||
return l.channel.RemoteNextRevocation() != nil
|
||||
return l.channel.RemoteNextRevocation() != nil &&
|
||||
l.ShortChanID() != sourceHop
|
||||
}
|
||||
|
||||
// sampleNetworkFee samples the current fee rate on the network to get into the
|
||||
@ -603,7 +610,7 @@ func (l *channelLink) syncChanStates() error {
|
||||
}
|
||||
|
||||
case <-l.quit:
|
||||
return fmt.Errorf("shutting down")
|
||||
return ErrLinkShuttingDown
|
||||
|
||||
case <-chanSyncDeadline:
|
||||
return fmt.Errorf("didn't receive ChannelReestablish before " +
|
||||
@ -759,9 +766,12 @@ func (l *channelLink) htlcManager() {
|
||||
// re-synchronize state with the remote peer. settledHtlcs is a map of
|
||||
// HTLC's that we re-settled as part of the channel state sync.
|
||||
if l.cfg.SyncStates {
|
||||
if err := l.syncChanStates(); err != nil {
|
||||
err := l.syncChanStates()
|
||||
if err != nil {
|
||||
l.errorf("unable to synchronize channel states: %v", err)
|
||||
l.fail(err.Error())
|
||||
if err != ErrLinkShuttingDown {
|
||||
l.fail(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -1540,18 +1550,31 @@ func (l *channelLink) ShortChanID() lnwire.ShortChannelID {
|
||||
// within the chain.
|
||||
//
|
||||
// NOTE: Part of the ChannelLink interface.
|
||||
func (l *channelLink) UpdateShortChanID(sid lnwire.ShortChannelID) {
|
||||
func (l *channelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
|
||||
chanID := l.ChanID()
|
||||
|
||||
// Refresh the channel state's short channel ID by loading it from disk.
|
||||
// This ensures that the channel state accurately reflects the updated
|
||||
// short channel ID.
|
||||
err := l.channel.State().RefreshShortChanID()
|
||||
if err != nil {
|
||||
l.errorf("unable to refresh short_chan_id for chan_id=%v: %v",
|
||||
chanID, err)
|
||||
return sourceHop, err
|
||||
}
|
||||
|
||||
sid := l.channel.ShortChanID()
|
||||
|
||||
l.infof("Updating to short_chan_id=%v for chan_id=%v", sid, chanID)
|
||||
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
log.Infof("Updating short chan ID for ChannelPoint(%v)", l)
|
||||
|
||||
l.shortChanID = sid
|
||||
l.Unlock()
|
||||
|
||||
go func() {
|
||||
err := l.cfg.UpdateContractSignals(&contractcourt.ContractSignals{
|
||||
HtlcUpdates: l.htlcUpdates,
|
||||
ShortChanID: l.channel.ShortChanID(),
|
||||
ShortChanID: sid,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("Unable to update signals for "+
|
||||
@ -1559,7 +1582,7 @@ func (l *channelLink) UpdateShortChanID(sid lnwire.ShortChannelID) {
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
return sid, nil
|
||||
}
|
||||
|
||||
// ChanID returns the channel ID for the channel link. The channel ID is a more
|
||||
|
@ -402,3 +402,174 @@ func (m *memoryMailBox) MessageOutBox() chan lnwire.Message {
|
||||
func (m *memoryMailBox) PacketOutBox() chan *htlcPacket {
|
||||
return m.pktOutbox
|
||||
}
|
||||
|
||||
// mailOrchestrator is responsible for coordinating the creation and lifecycle
|
||||
// of mailboxes used within the switch. It supports the ability to create
|
||||
// mailboxes, reassign their short channel id's, deliver htlc packets, and
|
||||
// queue packets for mailboxes that have not been created due to a link's late
|
||||
// registration.
|
||||
type mailOrchestrator struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// mailboxes caches exactly one mailbox for all known channels.
|
||||
mailboxes map[lnwire.ChannelID]MailBox
|
||||
|
||||
// liveIndex maps a live short chan id to the primary mailbox key.
|
||||
// An index in liveIndex map is only entered under two conditions:
|
||||
// 1. A link has a non-zero short channel id at time of AddLink.
|
||||
// 2. A link receives a non-zero short channel via UpdateShortChanID.
|
||||
liveIndex map[lnwire.ShortChannelID]lnwire.ChannelID
|
||||
|
||||
// TODO(conner): add another pair of indexes:
|
||||
// chan_id -> short_chan_id
|
||||
// short_chan_id -> mailbox
|
||||
// so that Deliver can lookup mailbox directly once live,
|
||||
// but still queriable by channel_id.
|
||||
|
||||
// unclaimedPackets maps a live short chan id to queue of packets if no
|
||||
// mailbox has been created.
|
||||
unclaimedPackets map[lnwire.ShortChannelID][]*htlcPacket
|
||||
}
|
||||
|
||||
// newMailOrchestrator initializes a fresh mailOrchestrator.
|
||||
func newMailOrchestrator() *mailOrchestrator {
|
||||
return &mailOrchestrator{
|
||||
mailboxes: make(map[lnwire.ChannelID]MailBox),
|
||||
liveIndex: make(map[lnwire.ShortChannelID]lnwire.ChannelID),
|
||||
unclaimedPackets: make(map[lnwire.ShortChannelID][]*htlcPacket),
|
||||
}
|
||||
}
|
||||
|
||||
// Stop instructs the orchestrator to stop all active mailboxes.
|
||||
func (mo *mailOrchestrator) Stop() {
|
||||
for _, mailbox := range mo.mailboxes {
|
||||
mailbox.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreateMailBox returns an existing mailbox belonging to `chanID`, or
|
||||
// creates and returns a new mailbox if none is found.
|
||||
func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID) MailBox {
|
||||
// First, try lookup the mailbox directly using only the shared mutex.
|
||||
mo.mu.RLock()
|
||||
mailbox, ok := mo.mailboxes[chanID]
|
||||
if ok {
|
||||
mo.mu.RUnlock()
|
||||
return mailbox
|
||||
}
|
||||
mo.mu.RUnlock()
|
||||
|
||||
// Otherwise, we will try again with exclusive lock, creating a mailbox
|
||||
// if one still has not been created.
|
||||
mo.mu.Lock()
|
||||
mailbox = mo.exclusiveGetOrCreateMailBox(chanID)
|
||||
mo.mu.Unlock()
|
||||
|
||||
return mailbox
|
||||
}
|
||||
|
||||
// exclusiveGetOrCreateMailBox checks for the existence of a mailbox for the
|
||||
// given channel id. If none is found, a new one is creates, started, and
|
||||
// recorded.
|
||||
//
|
||||
// NOTE: This method MUST be invoked with the mailOrchestrator's exclusive lock.
|
||||
func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox(
|
||||
chanID lnwire.ChannelID) MailBox {
|
||||
|
||||
mailbox, ok := mo.mailboxes[chanID]
|
||||
if !ok {
|
||||
mailbox = newMemoryMailBox()
|
||||
mailbox.Start()
|
||||
mo.mailboxes[chanID] = mailbox
|
||||
}
|
||||
|
||||
return mailbox
|
||||
}
|
||||
|
||||
// BindLiveShortChanID registers that messages bound for a particular short
|
||||
// channel id should be forwarded to the mailbox corresponding to the given
|
||||
// channel id. This method also checks to see if there are any unclaimed
|
||||
// packets for this short_chan_id. If any are found, they are delivered to the
|
||||
// mailbox and removed (marked as claimed).
|
||||
func (mo *mailOrchestrator) BindLiveShortChanID(mailbox MailBox,
|
||||
cid lnwire.ChannelID, sid lnwire.ShortChannelID) {
|
||||
|
||||
mo.mu.Lock()
|
||||
// Update the mapping from short channel id to mailbox's channel id.
|
||||
mo.liveIndex[sid] = cid
|
||||
|
||||
// Retrieve any unclaimed packets destined for this mailbox.
|
||||
pkts := mo.unclaimedPackets[sid]
|
||||
delete(mo.unclaimedPackets, sid)
|
||||
mo.mu.Unlock()
|
||||
|
||||
// Deliver the unclaimed packets.
|
||||
for _, pkt := range pkts {
|
||||
mailbox.AddPacket(pkt)
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver lookups the target mailbox using the live index from short_chan_id
|
||||
// to channel_id. If the mailbox is found, the message is delivered directly.
|
||||
// Otherwise the packet is recorded as unclaimed, and will be delivered to the
|
||||
// mailbox upon the subsequent call to BindLiveShortChanID.
|
||||
func (mo *mailOrchestrator) Deliver(
|
||||
sid lnwire.ShortChannelID, pkt *htlcPacket) error {
|
||||
|
||||
var (
|
||||
mailbox MailBox
|
||||
found bool
|
||||
)
|
||||
|
||||
// First, try to find the channel id for the target short_chan_id. If
|
||||
// the link is live, we will also look up the created mailbox.
|
||||
mo.mu.RLock()
|
||||
chanID, isLive := mo.liveIndex[sid]
|
||||
if isLive {
|
||||
mailbox, found = mo.mailboxes[chanID]
|
||||
}
|
||||
mo.mu.RUnlock()
|
||||
|
||||
// The link is live and target mailbox was found, deliver immediately.
|
||||
if isLive && found {
|
||||
return mailbox.AddPacket(pkt)
|
||||
}
|
||||
|
||||
// If we detected that the link has not been made live, we will acquire
|
||||
// the exclusive lock preemptively in order to queue this packet in the
|
||||
// list of unclaimed packets.
|
||||
mo.mu.Lock()
|
||||
|
||||
// Double check to see if the mailbox has been not made live since the
|
||||
// release of the shared lock.
|
||||
//
|
||||
// NOTE: Checking again with the exclusive lock held prevents a race
|
||||
// condition where BindLiveShortChanID is interleaved between the
|
||||
// release of the shared lock, and acquiring the exclusive lock. The
|
||||
// result would be stuck packets, as they wouldn't be redelivered until
|
||||
// the next call to BindLiveShortChanID, which is expected to occur
|
||||
// infrequently.
|
||||
chanID, isLive = mo.liveIndex[sid]
|
||||
if isLive {
|
||||
// Reaching this point indicates the mailbox is actually live.
|
||||
// We'll try to load the mailbox using the fresh channel id.
|
||||
//
|
||||
// NOTE: This should never create a new mailbox, as the live
|
||||
// index should only be set if the mailbox had been initialized
|
||||
// beforehand. However, this does ensure that this case is
|
||||
// handled properly in the event that it could happen.
|
||||
mailbox = mo.exclusiveGetOrCreateMailBox(chanID)
|
||||
mo.mu.Unlock()
|
||||
|
||||
// Deliver the packet to the mailbox if it was found or created.
|
||||
return mailbox.AddPacket(pkt)
|
||||
}
|
||||
|
||||
// Finally, if the channel id is still not found in the live index,
|
||||
// we'll add this to the list of unclaimed packets. These will be
|
||||
// delivered upon the next call to BindLiveShortChanID.
|
||||
mo.unclaimedPackets[sid] = append(mo.unclaimedPackets[sid], pkt)
|
||||
mo.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -147,3 +147,121 @@ func TestMailBoxCouriers(t *testing.T) {
|
||||
spew.Sdump(sentPackets), spew.Sdump(recvdPackets))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMailOrchestrator asserts that the orchestrator properly buffers packets
|
||||
// for channels that haven't been made live, such that they are delivered
|
||||
// immediately after BindLiveShortChanID. It also tests that packets are delivered
|
||||
// readily to mailboxes for channels that are already in the live state.
|
||||
func TestMailOrchestrator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll create a new instance of our orchestrator.
|
||||
mo := newMailOrchestrator()
|
||||
defer mo.Stop()
|
||||
|
||||
// We'll be delivering 10 htlc packets via the orchestrator.
|
||||
const numPackets = 10
|
||||
const halfPackets = numPackets / 2
|
||||
|
||||
// Before any mailbox is created or made live, we will deliver half of
|
||||
// the htlcs via the orchestrator.
|
||||
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
|
||||
sentPackets := make([]*htlcPacket, halfPackets)
|
||||
for i := 0; i < halfPackets; i++ {
|
||||
pkt := &htlcPacket{
|
||||
outgoingChanID: aliceChanID,
|
||||
outgoingHTLCID: uint64(i),
|
||||
incomingChanID: bobChanID,
|
||||
incomingHTLCID: uint64(i),
|
||||
amount: lnwire.MilliSatoshi(prand.Int63()),
|
||||
}
|
||||
sentPackets[i] = pkt
|
||||
|
||||
mo.Deliver(pkt.outgoingChanID, pkt)
|
||||
}
|
||||
|
||||
// Now, initialize a new mailbox for Alice's chanid.
|
||||
mailbox := mo.GetOrCreateMailBox(chanID1)
|
||||
|
||||
// Verify that no messages are received, since Alice's mailbox has not
|
||||
// been made live.
|
||||
for i := 0; i < halfPackets; i++ {
|
||||
timeout := time.After(50 * time.Millisecond)
|
||||
select {
|
||||
case <-mailbox.MessageOutBox():
|
||||
t.Fatalf("should not receive wire msg after reset")
|
||||
case <-timeout:
|
||||
}
|
||||
}
|
||||
|
||||
// Assign a short chan id to the existing mailbox, make it available for
|
||||
// capturing incoming HTLCs. The HTLCs added above should be delivered
|
||||
// immediately.
|
||||
mo.BindLiveShortChanID(mailbox, chanID1, aliceChanID)
|
||||
|
||||
// Verify that all of the packets are queued and delivered to Alice's
|
||||
// mailbox.
|
||||
recvdPackets := make([]*htlcPacket, 0, len(sentPackets))
|
||||
for i := 0; i < halfPackets; i++ {
|
||||
timeout := time.After(5 * time.Second)
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("didn't recv pkt %d after timeout", i)
|
||||
case pkt := <-mailbox.PacketOutBox():
|
||||
recvdPackets = append(recvdPackets, pkt)
|
||||
}
|
||||
}
|
||||
|
||||
// We should have received half of the total number of packets.
|
||||
if len(recvdPackets) != halfPackets {
|
||||
t.Fatalf("expected %v packets instead got %v",
|
||||
halfPackets, len(recvdPackets))
|
||||
}
|
||||
|
||||
// Check that the received packets are equal to the sent packets.
|
||||
if !reflect.DeepEqual(recvdPackets, sentPackets) {
|
||||
t.Fatalf("recvd packets mismatched: expected %v, got %v",
|
||||
spew.Sdump(sentPackets), spew.Sdump(recvdPackets))
|
||||
}
|
||||
|
||||
// For the second half of the test, create a new mailbox for Bob and
|
||||
// immediately make it live with an assigned short chan id.
|
||||
mailbox = mo.GetOrCreateMailBox(chanID2)
|
||||
mo.BindLiveShortChanID(mailbox, chanID2, bobChanID)
|
||||
|
||||
// Create the second half of our htlcs, and deliver them via the
|
||||
// orchestrator. We should be able to receive each of these in order.
|
||||
recvdPackets = make([]*htlcPacket, 0, len(sentPackets))
|
||||
for i := 0; i < halfPackets; i++ {
|
||||
pkt := &htlcPacket{
|
||||
outgoingChanID: aliceChanID,
|
||||
outgoingHTLCID: uint64(halfPackets + i),
|
||||
incomingChanID: bobChanID,
|
||||
incomingHTLCID: uint64(halfPackets + i),
|
||||
amount: lnwire.MilliSatoshi(prand.Int63()),
|
||||
}
|
||||
sentPackets[i] = pkt
|
||||
|
||||
mo.Deliver(pkt.incomingChanID, pkt)
|
||||
|
||||
timeout := time.After(50 * time.Millisecond)
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("didn't recv pkt %d after timeout", halfPackets+i)
|
||||
case pkt := <-mailbox.PacketOutBox():
|
||||
recvdPackets = append(recvdPackets, pkt)
|
||||
}
|
||||
}
|
||||
|
||||
// Again, we should have received half of the total number of packets.
|
||||
if len(recvdPackets) != halfPackets {
|
||||
t.Fatalf("expected %v packets instead got %v",
|
||||
halfPackets, len(recvdPackets))
|
||||
}
|
||||
|
||||
// Check that the received packets are equal to the sent packets.
|
||||
if !reflect.DeepEqual(recvdPackets, sentPackets) {
|
||||
t.Fatalf("recvd packets mismatched: expected %v, got %v",
|
||||
spew.Sdump(sentPackets), spew.Sdump(recvdPackets))
|
||||
}
|
||||
}
|
||||
|
@ -627,13 +627,17 @@ func (f *mockChannelLink) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID }
|
||||
func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID }
|
||||
func (f *mockChannelLink) UpdateShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
|
||||
func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { return 99999999 }
|
||||
func (f *mockChannelLink) Peer() Peer { return f.peer }
|
||||
func (f *mockChannelLink) Stop() {}
|
||||
func (f *mockChannelLink) EligibleToForward() bool { return f.eligible }
|
||||
func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID }
|
||||
func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID }
|
||||
func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { return 99999999 }
|
||||
func (f *mockChannelLink) Peer() Peer { return f.peer }
|
||||
func (f *mockChannelLink) Stop() {}
|
||||
func (f *mockChannelLink) EligibleToForward() bool { return f.eligible }
|
||||
func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
|
||||
func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
|
||||
f.eligible = true
|
||||
return f.shortChanID, nil
|
||||
}
|
||||
|
||||
var _ ChannelLink = (*mockChannelLink)(nil)
|
||||
|
||||
|
@ -175,17 +175,20 @@ type Switch struct {
|
||||
// forward the settle/fail htlc updates back to the add htlc initiator.
|
||||
circuits CircuitMap
|
||||
|
||||
// mailMtx is a read/write mutex that protects the mailboxes map.
|
||||
mailMtx sync.RWMutex
|
||||
|
||||
// mailboxes is a map of channel id to mailboxes, which allows the
|
||||
// switch to buffer messages for peers that have not come back online.
|
||||
mailboxes map[lnwire.ShortChannelID]MailBox
|
||||
// mailOrchestrator manages the lifecycle of mailboxes used throughout
|
||||
// the switch, and facilitates delayed delivery of packets to links that
|
||||
// later come online.
|
||||
mailOrchestrator *mailOrchestrator
|
||||
|
||||
// indexMtx is a read/write mutex that protects the set of indexes
|
||||
// below.
|
||||
indexMtx sync.RWMutex
|
||||
|
||||
// pendingLinkIndex holds links that have not had their final, live
|
||||
// short_chan_id assigned. These links can be transitioned into the
|
||||
// primary linkIndex by using UpdateShortChanID to load their live id.
|
||||
pendingLinkIndex map[lnwire.ChannelID]ChannelLink
|
||||
|
||||
// links is a map of channel id and channel link which manages
|
||||
// this channel.
|
||||
linkIndex map[lnwire.ChannelID]ChannelLink
|
||||
@ -248,9 +251,10 @@ func New(cfg Config) (*Switch, error) {
|
||||
circuits: circuitMap,
|
||||
paymentSequencer: sequencer,
|
||||
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
||||
mailboxes: make(map[lnwire.ShortChannelID]MailBox),
|
||||
mailOrchestrator: newMailOrchestrator(),
|
||||
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
|
||||
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
|
||||
pendingLinkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
||||
pendingPayments: make(map[uint64]*pendingPayment),
|
||||
htlcPlex: make(chan *plexPacket),
|
||||
chanCloseRequests: make(chan *ChanClose),
|
||||
@ -1089,8 +1093,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||
|
||||
// Check to see that the source link is online before removing
|
||||
// the circuit.
|
||||
sourceMailbox := s.getOrCreateMailBox(packet.incomingChanID)
|
||||
return sourceMailbox.AddPacket(packet)
|
||||
return s.mailOrchestrator.Deliver(packet.incomingChanID, packet)
|
||||
|
||||
default:
|
||||
return errors.New("wrong update type")
|
||||
@ -1116,16 +1119,18 @@ func (s *Switch) failAddPacket(packet *htlcPacket,
|
||||
|
||||
log.Error(failErr)
|
||||
|
||||
// Route a fail packet back to the source link.
|
||||
sourceMailbox := s.getOrCreateMailBox(packet.incomingChanID)
|
||||
if err = sourceMailbox.AddPacket(&htlcPacket{
|
||||
failPkt := &htlcPacket{
|
||||
incomingChanID: packet.incomingChanID,
|
||||
incomingHTLCID: packet.incomingHTLCID,
|
||||
circuit: packet.circuit,
|
||||
htlc: &lnwire.UpdateFailHTLC{
|
||||
Reason: reason,
|
||||
},
|
||||
}); err != nil {
|
||||
}
|
||||
|
||||
// Route a fail packet back to the source link.
|
||||
err = s.mailOrchestrator.Deliver(failPkt.incomingChanID, failPkt)
|
||||
if err != nil {
|
||||
err = errors.Errorf("source chanid=%v unable to "+
|
||||
"handle switch packet: %v",
|
||||
packet.incomingChanID, err)
|
||||
@ -1343,6 +1348,12 @@ func (s *Switch) htlcForwarder() {
|
||||
"channel link on stop: %v", err)
|
||||
}
|
||||
}
|
||||
for _, link := range s.pendingLinkIndex {
|
||||
if err := s.removeLink(link.ChanID()); err != nil {
|
||||
log.Errorf("unable to remove pending "+
|
||||
"channel link on stop: %v", err)
|
||||
}
|
||||
}
|
||||
s.indexMtx.Unlock()
|
||||
|
||||
// Before we exit fully, we'll attempt to flush out any
|
||||
@ -1560,7 +1571,7 @@ func (s *Switch) reforwardResponses() error {
|
||||
}
|
||||
|
||||
for _, activeChannel := range activeChannels {
|
||||
shortChanID := activeChannel.ShortChanID
|
||||
shortChanID := activeChannel.ShortChanID()
|
||||
fwdPkgs, err := s.loadChannelFwdPkgs(shortChanID)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1699,9 +1710,7 @@ func (s *Switch) Stop() error {
|
||||
// Wait until all active goroutines have finished exiting before
|
||||
// stopping the mailboxes, otherwise the mailbox map could still be
|
||||
// accessed and modified.
|
||||
for _, mailBox := range s.mailboxes {
|
||||
mailBox.Stop()
|
||||
}
|
||||
s.mailOrchestrator.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1712,64 +1721,67 @@ func (s *Switch) AddLink(link ChannelLink) error {
|
||||
s.indexMtx.Lock()
|
||||
defer s.indexMtx.Unlock()
|
||||
|
||||
// First we'll add the link to the linkIndex which lets us quickly look
|
||||
// up a channel when we need to close or register it, and the
|
||||
// forwarding index which'll be used when forwarding HTLC's in the
|
||||
// multi-hop setting.
|
||||
chanID := link.ChanID()
|
||||
|
||||
// Get and attach the mailbox for this link, which buffers packets in
|
||||
// case there packets that we tried to deliver while this link was
|
||||
// offline.
|
||||
mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID)
|
||||
link.AttachMailBox(mailbox)
|
||||
|
||||
if err := link.Start(); err != nil {
|
||||
s.removeLink(chanID)
|
||||
return err
|
||||
}
|
||||
|
||||
shortChanID := link.ShortChanID()
|
||||
if shortChanID == sourceHop {
|
||||
log.Infof("Adding pending link chan_id=%v, short_chan_id=%v",
|
||||
chanID, shortChanID)
|
||||
|
||||
s.pendingLinkIndex[chanID] = link
|
||||
} else {
|
||||
log.Infof("Adding live link chan_id=%v, short_chan_id=%v",
|
||||
chanID, shortChanID)
|
||||
|
||||
s.addLiveLink(link)
|
||||
s.mailOrchestrator.BindLiveShortChanID(
|
||||
mailbox, chanID, shortChanID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLiveLink adds a link to all associated forwarding index, this makes it a
|
||||
// candidate for forwarding HTLCs.
|
||||
func (s *Switch) addLiveLink(link ChannelLink) {
|
||||
// We'll add the link to the linkIndex which lets us quickly
|
||||
// look up a channel when we need to close or register it, and
|
||||
// the forwarding index which'll be used when forwarding HTLC's
|
||||
// in the multi-hop setting.
|
||||
s.linkIndex[link.ChanID()] = link
|
||||
s.forwardingIndex[link.ShortChanID()] = link
|
||||
|
||||
// Next we'll add the link to the interface index so we can quickly
|
||||
// look up all the channels for a particular node.
|
||||
// Next we'll add the link to the interface index so we can
|
||||
// quickly look up all the channels for a particular node.
|
||||
peerPub := link.Peer().PubKey()
|
||||
if _, ok := s.interfaceIndex[peerPub]; !ok {
|
||||
s.interfaceIndex[peerPub] = make(map[ChannelLink]struct{})
|
||||
}
|
||||
s.interfaceIndex[peerPub][link] = struct{}{}
|
||||
|
||||
// Get the mailbox for this link, which buffers packets in case there
|
||||
// packets that we tried to deliver while this link was offline.
|
||||
mailbox := s.getOrCreateMailBox(link.ShortChanID())
|
||||
|
||||
// Give the link its mailbox, we only need to start the mailbox if it
|
||||
// wasn't previously found.
|
||||
link.AttachMailBox(mailbox)
|
||||
|
||||
if err := link.Start(); err != nil {
|
||||
s.removeLink(link.ChanID())
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Added channel link with chan_id=%v, short_chan_id=(%v)",
|
||||
link.ChanID(), spew.Sdump(link.ShortChanID()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCreateMailBox returns the known mailbox for a particular short channel
|
||||
// id, or creates one if the link has no existing mailbox.
|
||||
func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox {
|
||||
// Check to see if we have a mailbox already populated for this link.
|
||||
s.mailMtx.RLock()
|
||||
mailbox, ok := s.mailboxes[chanID]
|
||||
if ok {
|
||||
s.mailMtx.RUnlock()
|
||||
return mailbox
|
||||
}
|
||||
s.mailMtx.RUnlock()
|
||||
// removeLiveLink removes a link from all associated forwarding indexes, this
|
||||
// prevents it from being a candidate in forwarding.
|
||||
func (s *Switch) removeLiveLink(link ChannelLink) {
|
||||
// Remove the channel from live link indexes.
|
||||
delete(s.linkIndex, link.ChanID())
|
||||
delete(s.forwardingIndex, link.ShortChanID())
|
||||
|
||||
// Otherwise, we will make a new one only if the mailbox still is not
|
||||
// present after the exclusive mutex is acquired.
|
||||
s.mailMtx.Lock()
|
||||
mailbox, ok = s.mailboxes[chanID]
|
||||
if !ok {
|
||||
mailbox = newMemoryMailBox()
|
||||
mailbox.Start()
|
||||
s.mailboxes[chanID] = mailbox
|
||||
}
|
||||
s.mailMtx.Unlock()
|
||||
|
||||
return mailbox
|
||||
// Remove the channel from channel index.
|
||||
peerPub := link.Peer().PubKey()
|
||||
delete(s.interfaceIndex, peerPub)
|
||||
}
|
||||
|
||||
// GetLink is used to initiate the handling of the get link command. The
|
||||
@ -1780,7 +1792,10 @@ func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) {
|
||||
|
||||
link, ok := s.linkIndex[chanID]
|
||||
if !ok {
|
||||
return nil, ErrChannelLinkNotFound
|
||||
link, ok = s.pendingLinkIndex[chanID]
|
||||
if !ok {
|
||||
return nil, ErrChannelLinkNotFound
|
||||
}
|
||||
}
|
||||
|
||||
return link, nil
|
||||
@ -1815,52 +1830,68 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
|
||||
log.Infof("Removing channel link with ChannelID(%v)", chanID)
|
||||
|
||||
link, ok := s.linkIndex[chanID]
|
||||
if !ok {
|
||||
return ErrChannelLinkNotFound
|
||||
if ok {
|
||||
s.removeLiveLink(link)
|
||||
link.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove the channel from channel map.
|
||||
delete(s.linkIndex, chanID)
|
||||
delete(s.forwardingIndex, link.ShortChanID())
|
||||
link, ok = s.pendingLinkIndex[chanID]
|
||||
if ok {
|
||||
delete(s.pendingLinkIndex, chanID)
|
||||
link.Stop()
|
||||
|
||||
// Remove the channel from channel index.
|
||||
peerPub := link.Peer().PubKey()
|
||||
delete(s.interfaceIndex, peerPub)
|
||||
return nil
|
||||
}
|
||||
|
||||
link.Stop()
|
||||
|
||||
return nil
|
||||
return ErrChannelLinkNotFound
|
||||
}
|
||||
|
||||
// UpdateShortChanID updates the short chan ID for an existing channel. This is
|
||||
// required in the case of a re-org and re-confirmation or a channel, or in the
|
||||
// case that a link was added to the switch before its short chan ID was known.
|
||||
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
|
||||
shortChanID lnwire.ShortChannelID) error {
|
||||
|
||||
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID) error {
|
||||
s.indexMtx.Lock()
|
||||
defer s.indexMtx.Unlock()
|
||||
|
||||
// First, we'll extract the current link as is from the link
|
||||
// index. If the link isn't even in the index, then we'll return an
|
||||
// error.
|
||||
link, ok := s.linkIndex[chanID]
|
||||
// Locate the target link in the pending link index. If no such link
|
||||
// exists, then we will ignore the request.
|
||||
link, ok := s.pendingLinkIndex[chanID]
|
||||
if !ok {
|
||||
s.indexMtx.Unlock()
|
||||
|
||||
return fmt.Errorf("link %v not found", chanID)
|
||||
}
|
||||
|
||||
log.Infof("Updating short_chan_id for ChannelLink(%v): old=%v, new=%v",
|
||||
chanID, link.ShortChanID(), shortChanID)
|
||||
oldShortChanID := link.ShortChanID()
|
||||
|
||||
// At this point the link is actually active, so we'll update the
|
||||
// forwarding index with the next short channel ID.
|
||||
s.forwardingIndex[shortChanID] = link
|
||||
// Try to update the link's short channel ID, returning early if this
|
||||
// update failed.
|
||||
shortChanID, err := link.UpdateShortChanID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.indexMtx.Unlock()
|
||||
// Reject any blank short channel ids.
|
||||
if shortChanID == sourceHop {
|
||||
return fmt.Errorf("refusing trivial short_chan_id for chan_id=%v"+
|
||||
"live link", chanID)
|
||||
}
|
||||
|
||||
// Finally, we'll notify the link of its new short channel ID.
|
||||
link.UpdateShortChanID(shortChanID)
|
||||
log.Infof("Updated short_chan_id for ChannelLink(%v): old=%v, new=%v",
|
||||
chanID, oldShortChanID, shortChanID)
|
||||
|
||||
// Since the link was in the pending state before, we will remove it
|
||||
// from the pending link index and add it to the live link index so that
|
||||
// it can be available in forwarding.
|
||||
delete(s.pendingLinkIndex, chanID)
|
||||
s.addLiveLink(link)
|
||||
|
||||
// Finally, alert the mail orchestrator to the change of short channel
|
||||
// ID, and deliver any unclaimed packets to the link.
|
||||
mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID)
|
||||
s.mailOrchestrator.BindLiveShortChanID(
|
||||
mailbox, chanID, shortChanID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package htlcswitch
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
@ -24,6 +25,101 @@ func genPreimage() ([32]byte, error) {
|
||||
return preimage, nil
|
||||
}
|
||||
|
||||
// TestSwitchSendPending checks the inability of htlc switch to forward adds
|
||||
// over pending links, and the UpdateShortChanID makes a pending link live.
|
||||
func TestSwitchSendPending(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
alicePeer, err := newMockServer(t, "alice", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create alice server: %v", err)
|
||||
}
|
||||
|
||||
s, err := initSwitchWithDB(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init switch: %v", err)
|
||||
}
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("unable to start switch: %v", err)
|
||||
}
|
||||
defer s.Stop()
|
||||
|
||||
chanID1, _, aliceChanID, bobChanID := genIDs()
|
||||
|
||||
pendingChanID := lnwire.ShortChannelID{}
|
||||
|
||||
aliceChannelLink := newMockChannelLink(
|
||||
s, chanID1, pendingChanID, alicePeer, false,
|
||||
)
|
||||
if err := s.AddLink(aliceChannelLink); err != nil {
|
||||
t.Fatalf("unable to add alice link: %v", err)
|
||||
}
|
||||
|
||||
// Create request which should is being forwarded from Bob channel
|
||||
// link to Alice channel link.
|
||||
preimage, err := genPreimage()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate preimage: %v", err)
|
||||
}
|
||||
rhash := fastsha256.Sum256(preimage[:])
|
||||
packet := &htlcPacket{
|
||||
incomingChanID: bobChanID,
|
||||
incomingHTLCID: 0,
|
||||
outgoingChanID: aliceChanID,
|
||||
obfuscator: NewMockObfuscator(),
|
||||
htlc: &lnwire.UpdateAddHTLC{
|
||||
PaymentHash: rhash,
|
||||
Amount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
// Send the ADD packet, this should not be forwarded out to the link
|
||||
// since there are no eligible links.
|
||||
err = s.forward(packet)
|
||||
expErr := fmt.Sprintf("unable to find link with destination %v",
|
||||
aliceChanID)
|
||||
if err != nil && err.Error() != expErr {
|
||||
t.Fatalf("expected forward failure: %v", err)
|
||||
}
|
||||
|
||||
// No message should be sent, since the packet was failed.
|
||||
select {
|
||||
case <-aliceChannelLink.packets:
|
||||
t.Fatal("expected not to receive message")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
// Since the packet should have been failed, there should be no active
|
||||
// circuits.
|
||||
if s.circuits.NumOpen() != 0 {
|
||||
t.Fatal("wrong amount of circuits")
|
||||
}
|
||||
|
||||
// Now, update Alice's link with her final short channel id. This should
|
||||
// move the link to the live state.
|
||||
aliceChannelLink.setLiveShortChanID(aliceChanID)
|
||||
err = s.UpdateShortChanID(chanID1)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to update alice short_chan_id: %v", err)
|
||||
}
|
||||
|
||||
// Increment the packet's HTLC index, so that it does not collide with
|
||||
// the prior attempt.
|
||||
packet.incomingHTLCID++
|
||||
|
||||
// Handle the request and checks that bob channel link received it.
|
||||
if err := s.forward(packet); err != nil {
|
||||
t.Fatalf("unexpected forward failure: %v", err)
|
||||
}
|
||||
|
||||
// Since Alice's link is now active, this packet should succeed.
|
||||
select {
|
||||
case <-aliceChannelLink.packets:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request was not propagated to alice")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSwitchForward checks the ability of htlc switch to forward add/settle
|
||||
// requests.
|
||||
func TestSwitchForward(t *testing.T) {
|
||||
|
@ -324,7 +324,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
RevocationStore: shachain.NewRevocationStore(),
|
||||
LocalCommitment: aliceCommit,
|
||||
RemoteCommitment: aliceCommit,
|
||||
ShortChanID: chanID,
|
||||
ShortChannelID: chanID,
|
||||
Db: dbAlice,
|
||||
Packager: channeldb.NewChannelPackager(chanID),
|
||||
FundingTxn: testTx,
|
||||
@ -343,7 +343,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
RevocationStore: shachain.NewRevocationStore(),
|
||||
LocalCommitment: bobCommit,
|
||||
RemoteCommitment: bobCommit,
|
||||
ShortChanID: chanID,
|
||||
ShortChannelID: chanID,
|
||||
Db: dbBob,
|
||||
Packager: channeldb.NewChannelPackager(chanID),
|
||||
}
|
||||
|
6
lnd.go
6
lnd.go
@ -460,11 +460,9 @@ func lndMain() error {
|
||||
// the chain arb so it can react to on-chain events.
|
||||
return server.chainArb.WatchNewChannel(channel)
|
||||
},
|
||||
ReportShortChanID: func(chanPoint wire.OutPoint,
|
||||
sid lnwire.ShortChannelID) error {
|
||||
|
||||
ReportShortChanID: func(chanPoint wire.OutPoint) error {
|
||||
cid := lnwire.NewChanIDFromOutPoint(&chanPoint)
|
||||
return server.htlcSwitch.UpdateShortChanID(cid, sid)
|
||||
return server.htlcSwitch.UpdateShortChanID(cid)
|
||||
},
|
||||
RequiredRemoteChanReserve: func(chanAmt btcutil.Amount) btcutil.Amount {
|
||||
// By default, we'll require the remote peer to maintain
|
||||
|
@ -4178,7 +4178,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
|
||||
}
|
||||
}
|
||||
|
||||
source := lc.channelState.ShortChanID
|
||||
source := lc.ShortChanID()
|
||||
|
||||
// Now that we have gathered the set of HTLCs to forward, separated by
|
||||
// type, construct a forwarding package using the height that the remote
|
||||
@ -4356,7 +4356,7 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte,
|
||||
htlc := lc.remoteUpdateLog.lookupHtlc(htlcIndex)
|
||||
if htlc == nil {
|
||||
return fmt.Errorf("No HTLC with ID %d in channel %v", htlcIndex,
|
||||
lc.channelState.ShortChanID)
|
||||
lc.ShortChanID())
|
||||
}
|
||||
|
||||
if htlc.RHash != sha256.Sum256(preimage[:]) {
|
||||
@ -4391,7 +4391,7 @@ func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, htlcIndex uint6
|
||||
htlc := lc.localUpdateLog.lookupHtlc(htlcIndex)
|
||||
if htlc == nil {
|
||||
return fmt.Errorf("No HTLC with ID %d in channel %v", htlcIndex,
|
||||
lc.channelState.ShortChanID)
|
||||
lc.ShortChanID())
|
||||
}
|
||||
|
||||
if htlc.RHash != sha256.Sum256(preimage[:]) {
|
||||
@ -4445,7 +4445,7 @@ func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte,
|
||||
htlc := lc.remoteUpdateLog.lookupHtlc(htlcIndex)
|
||||
if htlc == nil {
|
||||
return fmt.Errorf("No HTLC with ID %d in channel %v", htlcIndex,
|
||||
lc.channelState.ShortChanID)
|
||||
lc.ShortChanID())
|
||||
}
|
||||
|
||||
pd := &PaymentDescriptor{
|
||||
@ -4485,7 +4485,7 @@ func (lc *LightningChannel) MalformedFailHTLC(htlcIndex uint64,
|
||||
htlc := lc.remoteUpdateLog.lookupHtlc(htlcIndex)
|
||||
if htlc == nil {
|
||||
return fmt.Errorf("No HTLC with ID %d in channel %v", htlcIndex,
|
||||
lc.channelState.ShortChanID)
|
||||
lc.ShortChanID())
|
||||
}
|
||||
|
||||
pd := &PaymentDescriptor{
|
||||
@ -4518,7 +4518,7 @@ func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte,
|
||||
htlc := lc.localUpdateLog.lookupHtlc(htlcIndex)
|
||||
if htlc == nil {
|
||||
return fmt.Errorf("No HTLC with ID %d in channel %v", htlcIndex,
|
||||
lc.channelState.ShortChanID)
|
||||
lc.ShortChanID())
|
||||
}
|
||||
|
||||
pd := &PaymentDescriptor{
|
||||
@ -4546,7 +4546,7 @@ func (lc *LightningChannel) ChannelPoint() *wire.OutPoint {
|
||||
// ID encodes the exact location in the main chain that the original
|
||||
// funding output can be found.
|
||||
func (lc *LightningChannel) ShortChanID() lnwire.ShortChannelID {
|
||||
return lc.channelState.ShortChanID
|
||||
return lc.channelState.ShortChanID()
|
||||
}
|
||||
|
||||
// genHtlcScript generates the proper P2WSH public key scripts for the HTLC
|
||||
|
@ -270,7 +270,7 @@ func CreateTestChannels() (*LightningChannel, *LightningChannel, func(), error)
|
||||
RemoteChanCfg: bobCfg,
|
||||
IdentityPub: aliceKeys[0].PubKey(),
|
||||
FundingOutpoint: *prevOut,
|
||||
ShortChanID: shortChanID,
|
||||
ShortChannelID: shortChanID,
|
||||
ChanType: channeldb.SingleFunder,
|
||||
IsInitiator: true,
|
||||
Capacity: channelCapacity,
|
||||
@ -288,7 +288,7 @@ func CreateTestChannels() (*LightningChannel, *LightningChannel, func(), error)
|
||||
RemoteChanCfg: aliceCfg,
|
||||
IdentityPub: bobKeys[0].PubKey(),
|
||||
FundingOutpoint: *prevOut,
|
||||
ShortChanID: shortChanID,
|
||||
ShortChannelID: shortChanID,
|
||||
ChanType: channeldb.SingleFunder,
|
||||
IsInitiator: false,
|
||||
Capacity: channelCapacity,
|
||||
|
@ -371,7 +371,7 @@ func TestCommitmentAndHTLCTransactions(t *testing.T) {
|
||||
ChanType: channeldb.SingleFunder,
|
||||
ChainHash: *tc.netParams.GenesisHash,
|
||||
FundingOutpoint: tc.fundingOutpoint,
|
||||
ShortChanID: tc.shortChanID,
|
||||
ShortChannelID: tc.shortChanID,
|
||||
IsInitiator: true,
|
||||
IdentityPub: identityKey,
|
||||
LocalChanCfg: channeldb.ChannelConfig{
|
||||
|
2
pilot.go
2
pilot.go
@ -171,7 +171,7 @@ func initAutoPilot(svr *server, cfg *autoPilotConfig) (*autopilot.Agent, error)
|
||||
initialChanState := make([]autopilot.Channel, len(activeChannels))
|
||||
for i, channel := range activeChannels {
|
||||
initialChanState[i] = autopilot.Channel{
|
||||
ChanID: channel.ShortChanID,
|
||||
ChanID: channel.ShortChanID(),
|
||||
Capacity: channel.Capacity,
|
||||
Node: autopilot.NewNodeID(channel.IdentityPub),
|
||||
}
|
||||
|
@ -2306,7 +2306,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context,
|
||||
}
|
||||
|
||||
// Fetch the policies for each end of the channel.
|
||||
chanID := channel.ShortChanID.ToUint64()
|
||||
chanID := channel.ShortChanID().ToUint64()
|
||||
_, p1, p2, err := graph.FetchChannelEdgesByID(chanID)
|
||||
if err != nil {
|
||||
rpcsLog.Errorf("Unable to fetch the routing "+
|
||||
|
@ -242,7 +242,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
|
||||
RemoteChanCfg: bobCfg,
|
||||
IdentityPub: aliceKeyPub,
|
||||
FundingOutpoint: *prevOut,
|
||||
ShortChanID: shortChanID,
|
||||
ShortChannelID: shortChanID,
|
||||
ChanType: channeldb.SingleFunder,
|
||||
IsInitiator: true,
|
||||
Capacity: channelCapacity,
|
||||
|
Loading…
x
Reference in New Issue
Block a user