diff --git a/docs/release-notes/release-notes-0.16.1.md b/docs/release-notes/release-notes-0.16.1.md index 983f4f358..2b5cdc801 100644 --- a/docs/release-notes/release-notes-0.16.1.md +++ b/docs/release-notes/release-notes-0.16.1.md @@ -9,6 +9,9 @@ * [Allow caller to filter sessions at the time of reading them from disk](https://github.com/lightningnetwork/lnd/pull/7059) +* [Clean up sessions once all channels for which they have updates for are + closed. Also start sending the `DeleteSession` message to the + tower.](https://github.com/lightningnetwork/lnd/pull/7069) ## Misc diff --git a/itest/list_on_test.go b/itest/list_on_test.go index d2253969f..19f083cbb 100644 --- a/itest/list_on_test.go +++ b/itest/list_on_test.go @@ -515,4 +515,8 @@ var allTestCases = []*lntest.TestCase{ Name: "lookup htlc resolution", TestFunc: testLookupHtlcResolution, }, + { + Name: "watchtower session management", + TestFunc: testWatchtowerSessionManagement, + }, } diff --git a/itest/lnd_watchtower_test.go b/itest/lnd_watchtower_test.go new file mode 100644 index 000000000..3432848d8 --- /dev/null +++ b/itest/lnd_watchtower_test.go @@ -0,0 +1,172 @@ +package itest + +import ( + "fmt" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lnrpc/wtclientrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/node" + "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/stretchr/testify/require" +) + +// testWatchtowerSessionManagement tests that session deletion is done +// correctly. +func testWatchtowerSessionManagement(ht *lntest.HarnessTest) { + const ( + chanAmt = funding.MaxBtcFundingAmount + paymentAmt = 10_000 + numInvoices = 5 + maxUpdates = numInvoices * 2 + externalIP = "1.2.3.4" + sessionCloseRange = 1 + ) + + // Set up Wallis the watchtower who will be used by Dave to watch over + // his channel commitment transactions. + wallis := ht.NewNode("Wallis", []string{ + "--watchtower.active", + "--watchtower.externalip=" + externalIP, + }) + + wallisInfo := wallis.RPC.GetInfoWatchtower() + + // Assert that Wallis has one listener and it is 0.0.0.0:9911 or + // [::]:9911. Since no listener is explicitly specified, one of these + // should be the default depending on whether the host supports IPv6 or + // not. + require.Len(ht, wallisInfo.Listeners, 1) + listener := wallisInfo.Listeners[0] + require.True(ht, listener == "0.0.0.0:9911" || listener == "[::]:9911") + + // Assert the Wallis's URIs properly display the chosen external IP. + require.Len(ht, wallisInfo.Uris, 1) + require.Contains(ht, wallisInfo.Uris[0], externalIP) + + // Dave will be the tower client. + daveArgs := []string{ + "--wtclient.active", + fmt.Sprintf("--wtclient.max-updates=%d", maxUpdates), + fmt.Sprintf( + "--wtclient.session-close-range=%d", sessionCloseRange, + ), + } + dave := ht.NewNode("Dave", daveArgs) + + addTowerReq := &wtclientrpc.AddTowerRequest{ + Pubkey: wallisInfo.Pubkey, + Address: listener, + } + dave.RPC.AddTower(addTowerReq) + + // Assert that there exists a session between Dave and Wallis. + err := wait.NoError(func() error { + info := dave.RPC.GetTowerInfo(&wtclientrpc.GetTowerInfoRequest{ + Pubkey: wallisInfo.Pubkey, + IncludeSessions: true, + }) + + var numSessions uint32 + for _, sessionType := range info.SessionInfo { + numSessions += sessionType.NumSessions + } + if numSessions > 0 { + return nil + } + + return fmt.Errorf("expected a non-zero number of sessions") + }, defaultTimeout) + require.NoError(ht, err) + + // Before we make a channel, we'll load up Dave with some coins sent + // directly from the miner. + ht.FundCoins(btcutil.SatoshiPerBitcoin, dave) + + // Connect Dave and Alice. + ht.ConnectNodes(dave, ht.Alice) + + // Open a channel between Dave and Alice. + params := lntest.OpenChannelParams{ + Amt: chanAmt, + } + chanPoint := ht.OpenChannel(dave, ht.Alice, params) + + // Since there are 2 updates made for every payment and the maximum + // number of updates per session has been set to 10, make 5 payments + // between the pair so that the session is exhausted. + alicePayReqs, _, _ := ht.CreatePayReqs( + ht.Alice, paymentAmt, numInvoices, + ) + + send := func(node *node.HarnessNode, payReq string) { + stream := node.RPC.SendPayment(&routerrpc.SendPaymentRequest{ + PaymentRequest: payReq, + TimeoutSeconds: 60, + FeeLimitMsat: noFeeLimitMsat, + }) + + ht.AssertPaymentStatusFromStream( + stream, lnrpc.Payment_SUCCEEDED, + ) + } + + for i := 0; i < numInvoices; i++ { + send(dave, alicePayReqs[i]) + } + + // assertNumBackups is a closure that asserts that Dave has a certain + // number of backups backed up to the tower. If mineOnFail is true, + // then a block will be mined each time the assertion fails. + assertNumBackups := func(expected int, mineOnFail bool) { + err = wait.NoError(func() error { + info := dave.RPC.GetTowerInfo( + &wtclientrpc.GetTowerInfoRequest{ + Pubkey: wallisInfo.Pubkey, + IncludeSessions: true, + }, + ) + + var numBackups uint32 + for _, sessionType := range info.SessionInfo { + for _, session := range sessionType.Sessions { + numBackups += session.NumBackups + } + } + + if numBackups == uint32(expected) { + return nil + } + + if mineOnFail { + ht.Miner.MineBlocksSlow(1) + } + + return fmt.Errorf("expected %d backups, got %d", + expected, numBackups) + }, defaultTimeout) + require.NoError(ht, err) + } + + // Assert that one of the sessions now has 10 backups. + assertNumBackups(10, false) + + // Now close the channel and wait for the close transaction to appear + // in the mempool so that it is included in a block when we mine. + ht.CloseChannelAssertPending(dave, chanPoint, false) + + // Mine enough blocks to surpass the session-close-range. This should + // trigger the session to be deleted. + ht.MineBlocksAndAssertNumTxes(sessionCloseRange+6, 1) + + // Wait for the session to be deleted. We know it has been deleted once + // the number of backups is back to zero. We check for number of backups + // instead of number of sessions because it is expected that the client + // would immediately negotiate another session after deleting the + // exhausted one. This time we set the "mineOnFail" parameter to true to + // ensure that the session deleting logic is run. + assertNumBackups(0, true) +} diff --git a/lncfg/wtclient.go b/lncfg/wtclient.go index 8b9f03939..feaae464c 100644 --- a/lncfg/wtclient.go +++ b/lncfg/wtclient.go @@ -17,6 +17,15 @@ type WtClient struct { // SweepFeeRate specifies the fee rate in sat/byte to be used when // constructing justice transactions sent to the tower. SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."` + + // SessionCloseRange is the range over which to choose a random number + // of blocks to wait after the last channel of a session is closed + // before sending the DeleteSession message to the tower server. + SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."` + + // MaxUpdates is the maximum number of updates to be backed up in a + // single tower sessions. + MaxUpdates uint16 `long:"max-updates" description:"The maximum number of updates to be backed up in a single session."` } // Validate ensures the user has provided a valid configuration. diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 415ae1702..335c370c8 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -309,6 +309,7 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, } t.SessionInfo = append(t.SessionInfo, rpcTower.SessionInfo...) + t.Sessions = append(t.Sessions, rpcTower.Sessions...) } towers := make([]*Tower, 0, len(rpcTowers)) @@ -365,6 +366,9 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context, rpcTower.SessionInfo = append( rpcTower.SessionInfo, rpcLegacyTower.SessionInfo..., ) + rpcTower.Sessions = append( + rpcTower.Sessions, rpcLegacyTower.Sessions..., + ) return rpcTower, nil } diff --git a/lntest/rpc/watchtower.go b/lntest/rpc/watchtower.go index 1b05d15ea..d5512e7a4 100644 --- a/lntest/rpc/watchtower.go +++ b/lntest/rpc/watchtower.go @@ -24,6 +24,20 @@ func (h *HarnessRPC) GetInfoWatchtower() *watchtowerrpc.GetInfoResponse { return info } +// GetTowerInfo makes an RPC call to the watchtower client of the given node and +// asserts. +func (h *HarnessRPC) GetTowerInfo( + req *wtclientrpc.GetTowerInfoRequest) *wtclientrpc.Tower { + + ctxt, cancel := context.WithTimeout(h.runCtx, DefaultTimeout) + defer cancel() + + info, err := h.WatchtowerClient.GetTowerInfo(ctxt, req) + h.NoError(err, "GetTowerInfo from WatchtowerClient") + + return info +} + // AddTower makes a RPC call to the WatchtowerClient of the given node and // asserts. func (h *HarnessRPC) AddTower( diff --git a/sample-lnd.conf b/sample-lnd.conf index 3dfc76da8..b0499294b 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -997,6 +997,15 @@ litecoin.node=ltcd ; supported at this time, if none are provided the tower will not be enabled. ; wtclient.private-tower-uris= +; The range over which to choose a random number of blocks to wait after the +; last channel of a session is closed before sending the DeleteSession message +; to the tower server. The default is currently 288. Note that setting this to +; a lower value will result in faster session cleanup _but_ that this comes +; along with reduced privacy from the tower server. +; wtclient.session-close-range=10 + +; The maximum number of updates to include in a tower session. +; wtclient.max-updates=1024 [healthcheck] diff --git a/server.go b/server.go index 7758ebe8c..2224976e5 100644 --- a/server.go +++ b/server.go @@ -1497,6 +1497,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr, policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight() } + if cfg.WtClient.MaxUpdates != 0 { + policy.MaxUpdates = cfg.WtClient.MaxUpdates + } + + sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange) + if cfg.WtClient.SessionCloseRange != 0 { + sessionCloseRange = cfg.WtClient.SessionCloseRange + } + if err := policy.Validate(); err != nil { return nil, err } @@ -1512,7 +1521,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ) } + fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID + s.towerClient, err = wtclient.New(&wtclient.Config{ + FetchClosedChannel: fetchClosedChannel, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, + SubscribeChannelEvents: func() (subscribe.Subscription, + error) { + + return s.channelNotifier. + SubscribeChannelEvents() + }, Signer: cc.Wallet.Cfg.Signer, NewAddress: newSweepPkScriptGen(cc.Wallet), SecretKeyRing: s.cc.KeyRing, @@ -1536,6 +1556,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr, blob.Type(blob.FlagAnchorChannel) s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ + FetchClosedChannel: fetchClosedChannel, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, + SubscribeChannelEvents: func() (subscribe.Subscription, + error) { + + return s.channelNotifier. + SubscribeChannelEvents() + }, Signer: cc.Wallet.Cfg.Signer, NewAddress: newSweepPkScriptGen(cc.Wallet), SecretKeyRing: s.cc.KeyRing, diff --git a/watchtower/wtclient/addr_iterator.go b/watchtower/wtclient/addr_iterator.go index 87065c011..cb16d335a 100644 --- a/watchtower/wtclient/addr_iterator.go +++ b/watchtower/wtclient/addr_iterator.go @@ -69,6 +69,12 @@ type AddressIterator interface { // Reset clears the iterators state, and makes the address at the front // of the list the next item to be returned. Reset() + + // Copy constructs a new AddressIterator that has the same addresses + // as this iterator. + // + // NOTE that the address locks are not expected to be copied. + Copy() AddressIterator } // A compile-time check to ensure that addressIterator implements the @@ -324,6 +330,33 @@ func (a *addressIterator) GetAll() []net.Addr { a.mu.Lock() defer a.mu.Unlock() + return a.getAllUnsafe() +} + +// Copy constructs a new AddressIterator that has the same addresses +// as this iterator. +// +// NOTE that the address locks will not be copied. +func (a *addressIterator) Copy() AddressIterator { + a.mu.Lock() + defer a.mu.Unlock() + + addrs := a.getAllUnsafe() + + // Since newAddressIterator will only ever return an error if it is + // initialised with zero addresses, we can ignore the error here since + // we are initialising it with the set of addresses of this + // addressIterator which is by definition a non-empty list. + iter, _ := newAddressIterator(addrs...) + + return iter +} + +// getAllUnsafe returns a copy of all the addresses in the iterator. +// +// NOTE: this method is not thread safe and so must only be called once the +// addressIterator mutex is already being held. +func (a *addressIterator) getAllUnsafe() []net.Addr { var addrs []net.Addr cursor := a.addrList.Front() diff --git a/watchtower/wtclient/addr_iterator_test.go b/watchtower/wtclient/addr_iterator_test.go index d3674d985..89a35c4cb 100644 --- a/watchtower/wtclient/addr_iterator_test.go +++ b/watchtower/wtclient/addr_iterator_test.go @@ -97,6 +97,11 @@ func TestAddrIterator(t *testing.T) { addrList := iter.GetAll() require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3}) + // Also check that an iterator constructed via the Copy method, also + // contains all the expected addresses. + newIterAddrs := iter.Copy().GetAll() + require.ElementsMatch(t, newIterAddrs, []net.Addr{addr1, addr2, addr3}) + // Let's now remove addr3. err = iter.Remove(addr3) require.NoError(t, err) diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index faf3169c6..10ef86465 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -29,6 +29,10 @@ type TowerCandidateIterator interface { // candidates available as long as they remain in the set. Reset() error + // GetTower gets the tower with the given ID from the iterator. If no + // such tower is found then ErrTowerNotInIterator is returned. + GetTower(id wtdb.TowerID) (*Tower, error) + // Next returns the next candidate tower. The iterator is not required // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. @@ -76,6 +80,20 @@ func (t *towerListIterator) Reset() error { return nil } +// GetTower gets the tower with the given ID from the iterator. If no such tower +// is found then ErrTowerNotInIterator is returned. +func (t *towerListIterator) GetTower(id wtdb.TowerID) (*Tower, error) { + t.mu.Lock() + defer t.mu.Unlock() + + tower, ok := t.candidates[id] + if !ok { + return nil, ErrTowerNotInIterator + } + + return tower, nil +} + // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 7fe6ba723..b4df80f4d 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -52,14 +52,10 @@ func randTower(t *testing.T) *Tower { func copyTower(t *testing.T, tower *Tower) *Tower { t.Helper() - addrs := tower.Addresses.GetAll() - addrIterator, err := newAddressIterator(addrs...) - require.NoError(t, err) - return &Tower{ ID: tower.ID, IdentityKey: tower.IdentityKey, - Addresses: addrIterator, + Addresses: tower.Addresses.Copy(), } } @@ -83,9 +79,15 @@ func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) { tower, err := i.Next() require.NoError(t, err) - require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey)) - require.Equal(t, tower.ID, c.ID) - require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll()) + assertTowersEqual(t, c, tower) +} + +func assertTowersEqual(t *testing.T, expected, actual *Tower) { + t.Helper() + + require.True(t, expected.IdentityKey.IsEqual(actual.IdentityKey)) + require.Equal(t, expected.ID, actual.ID) + require.Equal(t, expected.Addresses.GetAll(), actual.Addresses.GetAll()) } // TestTowerCandidateIterator asserts the internal state of a @@ -155,4 +157,16 @@ func TestTowerCandidateIterator(t *testing.T) { towerIterator.AddCandidate(secondTower) assertActiveCandidate(t, towerIterator, secondTower, true) assertNextCandidate(t, towerIterator, secondTower) + + // Assert that the GetTower correctly returns the tower too. + tower, err := towerIterator.GetTower(secondTower.ID) + require.NoError(t, err) + assertTowersEqual(t, secondTower, tower) + + // Now remove the tower and assert that GetTower returns expected error. + err = towerIterator.RemoveCandidate(secondTower.ID, nil) + require.NoError(t, err) + + _, err = towerIterator.GetTower(secondTower.ID) + require.ErrorIs(t, err, ErrTowerNotInIterator) } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 32622b7da..e92b8b4cf 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,7 +2,10 @@ package wtclient import ( "bytes" + "crypto/rand" + "errors" "fmt" + "math/big" "net" "sync" "time" @@ -11,11 +14,14 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" @@ -40,6 +46,11 @@ const ( // client should abandon any pending updates or session negotiations // before terminating. DefaultForceQuitDelay = 10 * time.Second + + // DefaultSessionCloseRange is the range over which we will generate a + // random number of blocks to delay closing a session after its last + // channel has been closed. + DefaultSessionCloseRange = 288 ) // genSessionFilter constructs a filter that can be used to select sessions only @@ -146,6 +157,19 @@ type Config struct { // transaction. Signer input.Signer + // SubscribeChannelEvents can be used to subscribe to channel event + // notifications. + SubscribeChannelEvents func() (subscribe.Subscription, error) + + // FetchClosedChannel can be used to fetch the info about a closed + // channel. If the channel is not found or not yet closed then + // channeldb.ErrClosedChannelNotFound will be returned. + FetchClosedChannel func(cid lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) + + // ChainNotifier can be used to subscribe to block notifications. + ChainNotifier chainntnfs.ChainNotifier + // NewAddress generates a new on-chain sweep pkscript. NewAddress func() ([]byte, error) @@ -201,6 +225,11 @@ type Config struct { // watchtowers. If the exponential backoff produces a timeout greater // than this value, the backoff will be clamped to MaxBackoff. MaxBackoff time.Duration + + // SessionCloseRange is the range over which we will generate a random + // number of blocks to delay closing a session after its last channel + // has been closed. + SessionCloseRange uint32 } // newTowerMsg is an internal message we'll use within the TowerClient to signal @@ -258,6 +287,8 @@ type TowerClient struct { sessionQueue *sessionQueue prevTask *backupTask + closableSessionQueue *sessionCloseMinHeap + backupMu sync.Mutex summaries wtdb.ChannelSummaries chanCommitHeights map[lnwire.ChannelID]uint64 @@ -269,6 +300,7 @@ type TowerClient struct { staleTowers chan *staleTowerMsg wg sync.WaitGroup + quit chan struct{} forceQuit chan struct{} } @@ -308,17 +340,19 @@ func New(config *Config) (*TowerClient, error) { } c := &TowerClient{ - cfg: cfg, - log: plog, - pipeline: newTaskPipeline(plog), - chanCommitHeights: make(map[lnwire.ChannelID]uint64), - activeSessions: make(sessionQueueSet), - summaries: chanSummaries, - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(ClientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - forceQuit: make(chan struct{}), + cfg: cfg, + log: plog, + pipeline: newTaskPipeline(plog), + chanCommitHeights: make(map[lnwire.ChannelID]uint64), + activeSessions: make(sessionQueueSet), + summaries: chanSummaries, + closableSessionQueue: newSessionCloseMinHeap(), + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(ClientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + forceQuit: make(chan struct{}), + quit: make(chan struct{}), } // perUpdate is a callback function that will be used to inspect the @@ -364,7 +398,7 @@ func New(config *Config) (*TowerClient, error) { return } - log.Infof("Using private watchtower %s, offering policy %s", + c.log.Infof("Using private watchtower %s, offering policy %s", tower, cfg.Policy) // Add the tower to the set of candidate towers. @@ -435,27 +469,19 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, } for _, s := range sessions { - towerKeyDesc, err := keyRing.DeriveKey( - keychain.KeyLocator{ - Family: keychain.KeyFamilyTowerSession, - Index: s.KeyIndex, - }, + if !sessionFilter(s) { + continue + } + + cs, err := NewClientSessionFromDBSession( + s, tower, keyRing, ) if err != nil { return nil, err } - sessionKeyECDH := keychain.NewPubKeyECDH( - towerKeyDesc, keyRing, - ) - // Add the session to the set of candidate sessions. - candidateSessions[s.ID] = &ClientSession{ - ID: s.ID, - ClientSessionBody: s.ClientSessionBody, - Tower: tower, - SessionKeyECDH: sessionKeyECDH, - } + candidateSessions[s.ID] = cs perActiveTower(tower) } @@ -548,10 +574,70 @@ func (c *TowerClient) Start() error { } } + chanSub, err := c.cfg.SubscribeChannelEvents() + if err != nil { + returnErr = err + return + } + + // Iterate over the list of registered channels and check if + // any of them can be marked as closed. + for id := range c.summaries { + isClosed, closedHeight, err := c.isChannelClosed(id) + if err != nil { + returnErr = err + return + } + + if !isClosed { + continue + } + + _, err = c.cfg.DB.MarkChannelClosed(id, closedHeight) + if err != nil { + c.log.Errorf("could not mark channel(%s) as "+ + "closed: %v", id, err) + + continue + } + + // Since the channel has been marked as closed, we can + // also remove it from the channel summaries map. + delete(c.summaries, id) + } + + // Load all closable sessions. + closableSessions, err := c.cfg.DB.ListClosableSessions() + if err != nil { + returnErr = err + return + } + + err = c.trackClosableSessions(closableSessions) + if err != nil { + returnErr = err + return + } + + c.wg.Add(1) + go c.handleChannelCloses(chanSub) + + // Subscribe to new block events. + blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn( + nil, + ) + if err != nil { + returnErr = err + return + } + + c.wg.Add(1) + go c.handleClosableSessions(blockEvents) + // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. - err := c.negotiator.Start() + err = c.negotiator.Start() if err != nil { returnErr = err return @@ -599,6 +685,7 @@ func (c *TowerClient) Stop() error { // dispatcher to exit. The backup queue will signal it's // completion to the dispatcher, which releases the wait group // after all tasks have been assigned to session queues. + close(c.quit) c.wg.Wait() // 4. Since all valid tasks have been assigned to session @@ -780,6 +867,335 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { return c.getOrInitActiveQueue(candidateSession, updates), nil } +// handleChannelCloses listens for channel close events and marks channels as +// closed in the DB. +// +// NOTE: This method MUST be run as a goroutine. +func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { + defer c.wg.Done() + + c.log.Debugf("Starting channel close handler") + defer c.log.Debugf("Stopping channel close handler") + + for { + select { + case update, ok := <-chanSub.Updates(): + if !ok { + c.log.Debugf("Channel notifier has exited") + return + } + + // We only care about channel-close events. + event, ok := update.(channelnotifier.ClosedChannelEvent) + if !ok { + continue + } + + chanID := lnwire.NewChanIDFromOutPoint( + &event.CloseSummary.ChanPoint, + ) + + c.log.Debugf("Received ClosedChannelEvent for "+ + "channel: %s", chanID) + + err := c.handleClosedChannel( + chanID, event.CloseSummary.CloseHeight, + ) + if err != nil { + c.log.Errorf("Could not handle channel close "+ + "event for channel(%s): %v", chanID, + err) + } + + case <-c.forceQuit: + return + + case <-c.quit: + return + } + } +} + +// handleClosedChannel handles the closure of a single channel. It will mark the +// channel as closed in the DB, then it will handle all the sessions that are +// now closable due to the channel closure. +func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, + closeHeight uint32) error { + + c.backupMu.Lock() + defer c.backupMu.Unlock() + + // We only care about channels registered with the tower client. + if _, ok := c.summaries[chanID]; !ok { + return nil + } + + c.log.Debugf("Marking channel(%s) as closed", chanID) + + sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) + if err != nil { + return fmt.Errorf("could not mark channel(%s) as closed: %w", + chanID, err) + } + + closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) + for _, sess := range sessions { + closableSessions[sess] = closeHeight + } + + c.log.Debugf("Tracking %d new closable sessions as a result of "+ + "closing channel %s", len(closableSessions), chanID) + + err = c.trackClosableSessions(closableSessions) + if err != nil { + return fmt.Errorf("could not track closable sessions: %w", err) + } + + delete(c.summaries, chanID) + delete(c.chanCommitHeights, chanID) + + return nil +} + +// handleClosableSessions listens for new block notifications. For each block, +// it checks the closableSessionQueue to see if there is a closable session with +// a delete-height smaller than or equal to the new block, if there is then the +// tower is informed that it can delete the session, and then we also delete it +// from our DB. +func (c *TowerClient) handleClosableSessions( + blocksChan *chainntnfs.BlockEpochEvent) { + + defer c.wg.Done() + + c.log.Debug("Starting closable sessions handler") + defer c.log.Debug("Stopping closable sessions handler") + + for { + select { + case newBlock := <-blocksChan.Epochs: + if newBlock == nil { + return + } + + height := uint32(newBlock.Height) + for { + select { + case <-c.quit: + return + default: + } + + // If there are no closable sessions that we + // need to handle, then we are done and can + // reevaluate when the next block comes. + item := c.closableSessionQueue.Top() + if item == nil { + break + } + + // If there is closable session but the delete + // height we have set for it is after the + // current block height, then our work is done. + if item.deleteHeight > height { + break + } + + // Otherwise, we pop this item from the heap + // and handle it. + c.closableSessionQueue.Pop() + + // Fetch the session from the DB so that we can + // extract the Tower info. + sess, err := c.cfg.DB.GetClientSession( + item.sessionID, + ) + if err != nil { + c.log.Errorf("error calling "+ + "GetClientSession for "+ + "session %s: %v", + item.sessionID, err) + + continue + } + + err = c.deleteSessionFromTower(sess) + if err != nil { + c.log.Errorf("error deleting "+ + "session %s from tower: %v", + sess.ID, err) + + continue + } + + err = c.cfg.DB.DeleteSession(item.sessionID) + if err != nil { + c.log.Errorf("could not delete "+ + "session(%s) from DB: %w", + sess.ID, err) + + continue + } + } + + case <-c.forceQuit: + return + + case <-c.quit: + return + } + } +} + +// trackClosableSessions takes in a map of session IDs to the earliest block +// height at which the session should be deleted. For each of the sessions, +// a random delay is added to the block height and the session is added to the +// closableSessionQueue. +func (c *TowerClient) trackClosableSessions( + sessions map[wtdb.SessionID]uint32) error { + + // For each closable session, add a random delay to its close + // height and add it to the closableSessionQueue. + for sID, blockHeight := range sessions { + delay, err := newRandomDelay(c.cfg.SessionCloseRange) + if err != nil { + return err + } + + deleteHeight := blockHeight + delay + + c.closableSessionQueue.Push(&sessionCloseItem{ + sessionID: sID, + deleteHeight: deleteHeight, + }) + } + + return nil +} + +// deleteSessionFromTower dials the tower that we created the session with and +// attempts to send the tower the DeleteSession message. +func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error { + // First, we check if we have already loaded this tower in our + // candidate towers iterator. + tower, err := c.candidateTowers.GetTower(sess.TowerID) + if errors.Is(err, ErrTowerNotInIterator) { + // If not, then we attempt to load it from the DB. + dbTower, err := c.cfg.DB.LoadTowerByID(sess.TowerID) + if err != nil { + return err + } + + tower, err = NewTowerFromDBTower(dbTower) + if err != nil { + return err + } + } else if err != nil { + return err + } + + session, err := NewClientSessionFromDBSession( + sess, tower, c.cfg.SecretKeyRing, + ) + if err != nil { + return err + } + + localInit := wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), + c.cfg.ChainHash, + ) + + var ( + conn wtserver.Peer + + // addrIterator is a copy of the tower's address iterator. + // We use this copy so that iterating through the addresses does + // not affect any other threads using this iterator. + addrIterator = tower.Addresses.Copy() + towerAddr = addrIterator.Peek() + ) + // Attempt to dial the tower with its available addresses. + for { + conn, err = c.dial( + session.SessionKeyECDH, &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: towerAddr, + }, + ) + if err != nil { + // If there are more addrs available, immediately try + // those. + nextAddr, iteratorErr := addrIterator.Next() + if iteratorErr == nil { + towerAddr = nextAddr + continue + } + + // Otherwise, if we have exhausted the address list, + // exit. + addrIterator.Reset() + + return fmt.Errorf("failed to dial tower(%x) at any "+ + "available addresses", + tower.IdentityKey.SerializeCompressed()) + } + + break + } + defer conn.Close() + + // Send Init to tower. + err = c.sendMessage(conn, localInit) + if err != nil { + return err + } + + // Receive Init from tower. + remoteMsg, err := c.readMessage(conn) + if err != nil { + return err + } + + remoteInit, ok := remoteMsg.(*wtwire.Init) + if !ok { + return fmt.Errorf("watchtower %s responded with %T to Init", + towerAddr, remoteMsg) + } + + // Validate Init. + err = localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) + if err != nil { + return err + } + + // Send DeleteSession to tower. + err = c.sendMessage(conn, &wtwire.DeleteSession{}) + if err != nil { + return err + } + + // Receive DeleteSessionReply from tower. + remoteMsg, err = c.readMessage(conn) + if err != nil { + return err + } + + deleteSessionReply, ok := remoteMsg.(*wtwire.DeleteSessionReply) + if !ok { + return fmt.Errorf("watchtower %s responded with %T to "+ + "DeleteSession", towerAddr, remoteMsg) + } + + switch deleteSessionReply.Code { + case wtwire.CodeOK, wtwire.DeleteSessionCodeNotFound: + return nil + default: + return fmt.Errorf("received error code %v in "+ + "DeleteSessionReply when attempting to delete "+ + "session from tower", deleteSessionReply.Code) + } +} + // backupDispatcher processes events coming from the taskPipeline and is // responsible for detecting when the client needs to renegotiate a session to // fulfill continuing demand. The event loop exits after all tasks have been @@ -1153,6 +1569,22 @@ func (c *TowerClient) initActiveQueue(s *ClientSession, return sq } +// isChanClosed can be used to check if the channel with the given ID has been +// closed. If it has been, the block height in which its closing transaction was +// mined will also be returned. +func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32, + error) { + + chanSum, err := c.cfg.FetchClosedChannel(id) + if errors.Is(err, channeldb.ErrClosedChannelNotFound) { + return false, 0, nil + } else if err != nil { + return false, 0, err + } + + return true, chanSum.CloseHeight, nil +} + // AddTower adds a new watchtower reachable at the given address and considers // it for new sessions. If the watchtower already exists, then any new addresses // included will be considered when dialing it for session negotiations and @@ -1409,3 +1841,15 @@ func (c *TowerClient) logMessage( preposition, peer.RemotePub().SerializeCompressed(), peer.RemoteAddr()) } + +func newRandomDelay(max uint32) (uint32, error) { + var maxDelay big.Int + maxDelay.SetUint64(uint64(max)) + + randDelay, err := rand.Int(rand.Reader, &maxDelay) + if err != nil { + return 0, err + } + + return uint32(randDelay.Uint64()), nil +} diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 1490c6d10..2657e691b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1,6 +1,7 @@ package wtclient_test import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -15,12 +16,15 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" @@ -393,8 +397,15 @@ type testHarness struct { server *wtserver.Server net *mockNet - mu sync.Mutex - channels map[lnwire.ChannelID]*mockChannel + blockEvents *mockBlockSub + height int32 + + channelEvents *mockSubscription + sendUpdatesOn bool + + mu sync.Mutex + channels map[lnwire.ChannelID]*mockChannel + closedChannels map[lnwire.ChannelID]uint32 quit chan struct{} } @@ -441,41 +452,63 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { mockNet := newMockNet() clientDB := wtmock.NewClientDB() - clientCfg := &wtclient.Config{ - Signer: signer, - Dial: mockNet.Dial, - DB: clientDB, - AuthDial: mockNet.AuthDial, - SecretKeyRing: wtmock.NewSecretKeyRing(), - Policy: cfg.policy, - NewAddress: func() ([]byte, error) { - return addrScript, nil - }, - ReadTimeout: timeout, - WriteTimeout: timeout, - MinBackoff: time.Millisecond, - MaxBackoff: time.Second, - ForceQuitDelay: 10 * time.Second, - } - h := &testHarness{ - t: t, - cfg: cfg, - signer: signer, - capacity: cfg.localBalance + cfg.remoteBalance, - clientDB: clientDB, - clientCfg: clientCfg, - serverAddr: towerAddr, - serverDB: serverDB, - serverCfg: serverCfg, - net: mockNet, - channels: make(map[lnwire.ChannelID]*mockChannel), - quit: make(chan struct{}), + t: t, + cfg: cfg, + signer: signer, + capacity: cfg.localBalance + cfg.remoteBalance, + clientDB: clientDB, + serverAddr: towerAddr, + serverDB: serverDB, + serverCfg: serverCfg, + net: mockNet, + blockEvents: newMockBlockSub(t), + channelEvents: newMockSubscription(t), + channels: make(map[lnwire.ChannelID]*mockChannel), + closedChannels: make(map[lnwire.ChannelID]uint32), + quit: make(chan struct{}), } t.Cleanup(func() { close(h.quit) }) + fetchChannel := func(id lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + height, ok := h.closedChannels[id] + if !ok { + return nil, channeldb.ErrClosedChannelNotFound + } + + return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil + } + + h.clientCfg = &wtclient.Config{ + Signer: signer, + SubscribeChannelEvents: func() (subscribe.Subscription, error) { + return h.channelEvents, nil + }, + FetchClosedChannel: fetchChannel, + ChainNotifier: h.blockEvents, + Dial: mockNet.Dial, + DB: clientDB, + AuthDial: mockNet.AuthDial, + SecretKeyRing: wtmock.NewSecretKeyRing(), + Policy: cfg.policy, + NewAddress: func() ([]byte, error) { + return addrScript, nil + }, + ReadTimeout: timeout, + WriteTimeout: timeout, + MinBackoff: time.Millisecond, + MaxBackoff: time.Second, + ForceQuitDelay: 10 * time.Second, + SessionCloseRange: 1, + } + if !cfg.noServerStart { h.startServer() t.Cleanup(h.stopServer) @@ -492,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return h } +// mine mimics the mining of new blocks by sending new block notifications. +func (h *testHarness) mine(numBlocks int) { + h.t.Helper() + + for i := 0; i < numBlocks; i++ { + h.height++ + h.blockEvents.sendNewBlock(h.height) + } +} + // startServer creates a new server using the harness's current serverCfg and // starts it after pointing the mockNet's callback to the new server. func (h *testHarness) startServer() { @@ -576,6 +619,41 @@ func (h *testHarness) channel(id uint64) *mockChannel { return c } +// closeChannel marks a channel as closed. +// +// NOTE: The method fails if a channel for id does not exist. +func (h *testHarness) closeChannel(id uint64, height uint32) { + h.t.Helper() + + h.mu.Lock() + defer h.mu.Unlock() + + chanID := chanIDFromInt(id) + + _, ok := h.channels[chanID] + require.Truef(h.t, ok, "unable to fetch channel %d", id) + + h.closedChannels[chanID] = height + delete(h.channels, chanID) + + chanPointHash, err := chainhash.NewHash(chanID[:]) + require.NoError(h.t, err) + + if !h.sendUpdatesOn { + return + } + + h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{ + CloseSummary: &channeldb.ChannelCloseSummary{ + ChanPoint: wire.OutPoint{ + Hash: *chanPointHash, + Index: 0, + }, + CloseHeight: height, + }, + }) +} + // registerChannel registers the channel identified by id with the client. func (h *testHarness) registerChannel(id uint64) { h.t.Helper() @@ -624,7 +702,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { err := h.client.BackupState( &chanID, retribution, channeldb.SingleFunderBit, ) - require.ErrorIs(h.t, expErr, err) + require.ErrorIs(h.t, err, expErr) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -770,11 +848,132 @@ func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { require.NoError(h.t, err) } +// relevantSessions returns a list of session IDs that have acked updates for +// the given channel ID. +func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID { + h.t.Helper() + + var ( + sessionIDs []wtdb.SessionID + cID = chanIDFromInt(chanID) + ) + + collectSessions := wtdb.WithPerNumAckedUpdates( + func(session *wtdb.ClientSession, id lnwire.ChannelID, + _ uint16) { + + if !bytes.Equal(id[:], cID[:]) { + return + } + + sessionIDs = append(sessionIDs, session.ID) + }, + ) + + _, err := h.clientDB.ListClientSessions(nil, nil, collectSessions) + require.NoError(h.t, err) + + return sessionIDs +} + +// isSessionClosable returns true if the given session has been marked as +// closable in the DB. +func (h *testHarness) isSessionClosable(id wtdb.SessionID) bool { + h.t.Helper() + + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + _, ok := cs[id] + + return ok +} + +// mockSubscription is a mock subscription client that blocks on sends into the +// updates channel. +type mockSubscription struct { + t *testing.T + updates chan interface{} + + // Embed the subscription interface in this mock so that we satisfy it. + subscribe.Subscription +} + +// newMockSubscription creates a mock subscription. +func newMockSubscription(t *testing.T) *mockSubscription { + t.Helper() + + return &mockSubscription{ + t: t, + updates: make(chan interface{}), + } +} + +// sendUpdate sends an update into our updates channel, mocking the dispatch of +// an update from a subscription server. This call will fail the test if the +// update is not consumed within our timeout. +func (m *mockSubscription) sendUpdate(update interface{}) { + select { + case m.updates <- update: + + case <-time.After(waitTime): + m.t.Fatalf("update: %v timeout", update) + } +} + +// Updates returns the updates channel for the mock. +func (m *mockSubscription) Updates() <-chan interface{} { + return m.updates +} + +// mockBlockSub mocks out the ChainNotifier. +type mockBlockSub struct { + t *testing.T + events chan *chainntnfs.BlockEpoch + + chainntnfs.ChainNotifier +} + +// newMockBlockSub creates a new mockBlockSub. +func newMockBlockSub(t *testing.T) *mockBlockSub { + t.Helper() + + return &mockBlockSub{ + t: t, + events: make(chan *chainntnfs.BlockEpoch), + } +} + +// RegisterBlockEpochNtfn returns a channel that can be used to listen for new +// blocks. +func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Epochs: m.events, + }, nil +} + +// sendNewBlock will send a new block on the notification channel. +func (m *mockBlockSub) sendNewBlock(height int32) { + select { + case m.events <- &chainntnfs.BlockEpoch{Height: height}: + + case <-time.After(waitTime): + m.t.Fatalf("timed out sending block: %d", height) + } +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) ) +var defaultTxPolicy = wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, +} + type clientTest struct { name string cfg harnessCfg @@ -791,10 +990,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, noRegisterChan0: true, @@ -825,10 +1021,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, @@ -860,10 +1053,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -927,10 +1117,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, @@ -1006,10 +1193,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1062,10 +1246,7 @@ var clientTests = []clientTest{ localBalance: 100000001, // ensure (% amt != 0) remoteBalance: 200000001, // ensure (% amt != 0) policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 1000, }, }, @@ -1106,10 +1287,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1156,10 +1334,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, @@ -1212,10 +1387,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, @@ -1274,10 +1446,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 10, }, }, @@ -1333,10 +1502,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1381,10 +1547,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1489,10 +1652,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1557,10 +1717,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noServerStart: true, @@ -1654,6 +1811,209 @@ var clientTests = []clientTest{ }, waitTime) require.NoError(h.t, err) }, + }, { + name: "assert that sessions are correctly marked as closable", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: defaultTxPolicy, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const numUpdates = 5 + + // In this test we assert that a channel is correctly + // marked as closed and that sessions are also correctly + // marked as closable. + + // We start with the sendUpdatesOn parameter set to + // false so that we can test that channels are correctly + // evaluated at startup. + h.sendUpdatesOn = false + + // Advance channel 0 to create all states and back them + // all up. This will saturate the session with updates + // for channel 0 which means that the session should be + // considered closable when channel 0 is closed. + hints := h.advanceChannelN(0, numUpdates) + h.backupStates(0, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // We expect only 1 session to have updates for this + // channel. + sessionIDs := h.relevantSessions(0) + require.Len(h.t, sessionIDs, 1) + + // Since channel 0 is still open, the session should not + // yet be closable. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Close the channel. + h.closeChannel(0, 1) + + // Since updates are currently not being sent, we expect + // the session to still not be marked as closable. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Restart the client. + h.client.ForceQuit() + h.startClient() + + // The session should now have been marked as closable. + err := wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we set sendUpdatesOn to true and do the same with + // a new channel. A restart should now not be necessary + // anymore. + h.sendUpdatesOn = true + + h.makeChannel( + 1, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(1) + + hints = h.advanceChannelN(1, numUpdates) + h.backupStates(1, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // Determine the ID of the session of interest. + sessionIDs = h.relevantSessions(1) + + // We expect only 1 session to have updates for this + // channel. + require.Len(h.t, sessionIDs, 1) + + // Assert that the session is not yet closable since + // the channel is still open. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Now close the channel. + h.closeChannel(1, 1) + + // Since the updates have been turned on, the session + // should now show up as closable. + err = wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we test that a session must be exhausted with all + // channels closed before it is seen as closable. + h.makeChannel( + 2, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(2) + + // Fill up only half of the session updates. + hints = h.advanceChannelN(2, numUpdates) + h.backupStates(2, 0, numUpdates/2, nil) + h.waitServerUpdates(hints[:numUpdates/2], waitTime) + + // Determine the ID of the session of interest. + sessionIDs = h.relevantSessions(2) + + // We expect only 1 session to have updates for this + // channel. + require.Len(h.t, sessionIDs, 1) + + // Now close the channel. + h.closeChannel(2, 1) + + // The session should _not_ be closable due to it not + // being exhausted yet. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Create a new channel. + h.makeChannel( + 3, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(3) + + hints = h.advanceChannelN(3, numUpdates) + h.backupStates(3, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // Close it. + h.closeChannel(3, 1) + + // Now the session should be closable. + err = wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we will mine a few blocks. This will cause the + // necessary session-close-range to be exceeded meaning + // that the client should send the DeleteSession message + // to the server. We will assert that both the client + // and server have deleted the appropriate sessions and + // channel info. + + // Before we mine blocks, assert that the client + // currently has 3 closable sessions. + closableSess, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + require.Len(h.t, closableSess, 3) + + // Assert that the server is also aware of all of these + // sessions. + for sid := range closableSess { + _, err := h.serverDB.GetSessionInfo(&sid) + require.NoError(h.t, err) + } + + // Also make a note of the total number of sessions the + // client has. + sessions, err := h.clientDB.ListClientSessions(nil, nil) + require.NoError(h.t, err) + require.Len(h.t, sessions, 4) + + h.mine(3) + + // The client should no longer have any closable + // sessions and the total list of client sessions should + // no longer include the three that it previously had + // marked as closable. The server should also no longer + // have these sessions in its DB. + err = wait.Predicate(func() bool { + sess, err := h.clientDB.ListClientSessions( + nil, nil, + ) + require.NoError(h.t, err) + + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + if len(sess) != 1 || len(cs) != 0 { + return false + } + + for sid := range closableSess { + _, ok := sess[sid] + if ok { + return false + } + + _, err := h.serverDB.GetSessionInfo( + &sid, + ) + if !errors.Is( + err, wtdb.ErrSessionNotFound, + ) { + return false + } + } + + return true + + }, waitTime) + require.NoError(h.t, err) + }, }, } diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go index f496074bf..c6884bb35 100644 --- a/watchtower/wtclient/errors.go +++ b/watchtower/wtclient/errors.go @@ -1,6 +1,8 @@ package wtclient -import "errors" +import ( + "errors" +) var ( // ErrClientExiting signals that the watchtower client is shutting down. @@ -11,6 +13,10 @@ var ( ErrTowerCandidatesExhausted = errors.New("exhausted all tower " + "candidates") + // ErrTowerNotInIterator is returned when a requested tower was not + // found in the iterator. + ErrTowerNotInIterator = errors.New("tower not in iterator") + // ErrPermanentTowerFailure signals that the tower has reported that it // has permanently failed or the client believes this has happened based // on the tower's behavior. diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 9751988ba..4eebef4e5 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -64,6 +64,11 @@ type DB interface { ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) + // GetClientSession loads the ClientSession with the given ID from the + // DB. + GetClientSession(wtdb.SessionID, + ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) + // FetchSessionCommittedUpdates retrieves the current set of un-acked // updates of the given session. FetchSessionCommittedUpdates(id *wtdb.SessionID) ( @@ -78,9 +83,29 @@ type DB interface { NumAckedUpdates(id *wtdb.SessionID) (uint64, error) // FetchChanSummaries loads a mapping from all registered channels to - // their channel summaries. + // their channel summaries. Only the channels that have not yet been + // marked as closed will be loaded. FetchChanSummaries() (wtdb.ChannelSummaries, error) + // MarkChannelClosed will mark a registered channel as closed by setting + // its closed-height as the given block height. It returns a list of + // session IDs for sessions that are now considered closable due to the + // close of this channel. The details for this channel will be deleted + // from the DB if there are no more sessions in the DB that contain + // updates for this channel. + MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) ( + []wtdb.SessionID, error) + + // ListClosableSessions fetches and returns the IDs for all sessions + // marked as closable. + ListClosableSessions() (map[wtdb.SessionID]uint32, error) + + // DeleteSession can be called when a session should be deleted from the + // DB. All references to the session will also be deleted from the DB. + // A session will only be deleted if it was previously marked as + // closable. + DeleteSession(id wtdb.SessionID) error + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the @@ -174,3 +199,30 @@ type ClientSession struct { // key used to connect to the watchtower. SessionKeyECDH keychain.SingleKeyECDH } + +// NewClientSessionFromDBSession converts a wtdb.ClientSession to a +// ClientSession. +func NewClientSessionFromDBSession(s *wtdb.ClientSession, tower *Tower, + keyRing ECDHKeyRing) (*ClientSession, error) { + + towerKeyDesc, err := keyRing.DeriveKey( + keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: s.KeyIndex, + }, + ) + if err != nil { + return nil, err + } + + sessionKeyECDH := keychain.NewPubKeyECDH( + towerKeyDesc, keyRing, + ) + + return &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + }, nil +} diff --git a/watchtower/wtclient/sess_close_min_heap.go b/watchtower/wtclient/sess_close_min_heap.go new file mode 100644 index 000000000..c5f58ec1a --- /dev/null +++ b/watchtower/wtclient/sess_close_min_heap.go @@ -0,0 +1,95 @@ +package wtclient + +import ( + "sync" + + "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +// sessionCloseMinHeap is a thread-safe min-heap implementation that stores +// sessionCloseItem items and prioritises the item with the lowest block height. +type sessionCloseMinHeap struct { + queue queue.PriorityQueue + mu sync.Mutex +} + +// newSessionCloseMinHeap constructs a new sessionCloseMineHeap. +func newSessionCloseMinHeap() *sessionCloseMinHeap { + return &sessionCloseMinHeap{} +} + +// Len returns the length of the queue. +func (h *sessionCloseMinHeap) Len() int { + h.mu.Lock() + defer h.mu.Unlock() + + return h.queue.Len() +} + +// Empty returns true if the queue is empty. +func (h *sessionCloseMinHeap) Empty() bool { + h.mu.Lock() + defer h.mu.Unlock() + + return h.queue.Empty() +} + +// Push adds an item to the priority queue. +func (h *sessionCloseMinHeap) Push(item *sessionCloseItem) { + h.mu.Lock() + defer h.mu.Unlock() + + h.queue.Push(item) +} + +// Pop removes the top most item from the queue. +func (h *sessionCloseMinHeap) Pop() *sessionCloseItem { + h.mu.Lock() + defer h.mu.Unlock() + + if h.queue.Empty() { + return nil + } + + item := h.queue.Pop() + + return item.(*sessionCloseItem) //nolint:forcetypeassert +} + +// Top returns the top most item from the queue without removing it. +func (h *sessionCloseMinHeap) Top() *sessionCloseItem { + h.mu.Lock() + defer h.mu.Unlock() + + if h.queue.Empty() { + return nil + } + + item := h.queue.Top() + + return item.(*sessionCloseItem) //nolint:forcetypeassert +} + +// sessionCloseItem represents a session that is ready to be deleted. +type sessionCloseItem struct { + // sessionID is the ID of the session in question. + sessionID wtdb.SessionID + + // deleteHeight is the block height after which we can delete the + // session. + deleteHeight uint32 +} + +// Less returns true if the current item's delete height is less than the +// other sessionCloseItem's delete height. This results in lower block heights +// being popped first from the heap. +// +// NOTE: this is part of the queue.PriorityQueueItem interface. +func (s *sessionCloseItem) Less(other queue.PriorityQueueItem) bool { + o := other.(*sessionCloseItem).deleteHeight //nolint:forcetypeassert + + return s.deleteHeight < o +} + +var _ queue.PriorityQueueItem = (*sessionCloseItem)(nil) diff --git a/watchtower/wtclient/sess_close_min_heap_test.go b/watchtower/wtclient/sess_close_min_heap_test.go new file mode 100644 index 000000000..9983f5b93 --- /dev/null +++ b/watchtower/wtclient/sess_close_min_heap_test.go @@ -0,0 +1,52 @@ +package wtclient + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestSessionCloseMinHeap asserts that the sessionCloseMinHeap behaves as +// expected. +func TestSessionCloseMinHeap(t *testing.T) { + t.Parallel() + + heap := newSessionCloseMinHeap() + require.Nil(t, heap.Pop()) + require.Nil(t, heap.Top()) + require.True(t, heap.Empty()) + require.Zero(t, heap.Len()) + + // Add an item with height 10. + item1 := &sessionCloseItem{ + sessionID: [33]byte{1, 2, 3}, + deleteHeight: 10, + } + + heap.Push(item1) + require.Equal(t, item1, heap.Top()) + require.False(t, heap.Empty()) + require.EqualValues(t, 1, heap.Len()) + + // Add a bunch more items with heights 1, 2, 6, 11, 6, 30, 9. + heap.Push(&sessionCloseItem{deleteHeight: 1}) + heap.Push(&sessionCloseItem{deleteHeight: 2}) + heap.Push(&sessionCloseItem{deleteHeight: 6}) + heap.Push(&sessionCloseItem{deleteHeight: 11}) + heap.Push(&sessionCloseItem{deleteHeight: 6}) + heap.Push(&sessionCloseItem{deleteHeight: 30}) + heap.Push(&sessionCloseItem{deleteHeight: 9}) + + // Now pop from the queue and assert that the items are returned in + // ascending order. + require.EqualValues(t, 1, heap.Pop().deleteHeight) + require.EqualValues(t, 2, heap.Pop().deleteHeight) + require.EqualValues(t, 6, heap.Pop().deleteHeight) + require.EqualValues(t, 6, heap.Pop().deleteHeight) + require.EqualValues(t, 9, heap.Pop().deleteHeight) + require.EqualValues(t, 10, heap.Pop().deleteHeight) + require.EqualValues(t, 11, heap.Pop().deleteHeight) + require.EqualValues(t, 30, heap.Pop().deleteHeight) + require.Nil(t, heap.Pop()) + require.Zero(t, heap.Len()) +} diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 3cbe0c6f1..53c643b6f 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -23,22 +23,39 @@ var ( // cChanDetailsBkt is a top-level bucket storing: // channel-id => cChannelSummary -> encoded ClientChanSummary. // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 + // => cChanClosedHeight -> block-height cChanDetailsBkt = []byte("client-channel-detail-bucket") + // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: + // db-session-id -> 1 + cChanSessions = []byte("client-channel-sessions") + // cChanDBID is a key used in the cChanDetailsBkt to store the // db-assigned-id of a channel. cChanDBID = []byte("client-channel-db-id") + // cChanClosedHeight is a key used in the cChanDetailsBkt to store the + // block height at which the channel's closing transaction was mined in. + // If this there is no associated value for this key, then the channel + // has not yet been marked as closed. + cChanClosedHeight = []byte("client-channel-closed-height") + // cChannelSummary is a key used in cChanDetailsBkt to store the encoded // body of ClientChanSummary. cChannelSummary = []byte("client-channel-summary") // cSessionBkt is a top-level bucket storing: // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id // => cSessionCommits => seqnum -> encoded CommittedUpdate // => cSessionAckRangeIndex => db-chan-id => start -> end cSessionBkt = []byte("client-session-bucket") + // cSessionDBID is a key used in the cSessionBkt to store the + // db-assigned-id of a session. + cSessionDBID = []byte("client-session-db-id") + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of // the ClientSession. cSessionBody = []byte("client-session-body") @@ -55,6 +72,10 @@ var ( // db-assigned-id -> channel-ID cChanIDIndexBkt = []byte("client-channel-id-index") + // cSessionIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> session-id + cSessionIDIndexBkt = []byte("client-session-id-index") + // cTowerBkt is a top-level bucket storing: // tower-id -> encoded Tower. cTowerBkt = []byte("client-tower-bucket") @@ -69,6 +90,10 @@ var ( "client-tower-to-session-index-bucket", ) + // cClosableSessionsBkt is a top-level bucket storing: + // db-session-id -> last-channel-close-height + cClosableSessionsBkt = []byte("client-closable-sessions-bucket") + // ErrTowerNotFound signals that the target tower was not found in the // database. ErrTowerNotFound = errors.New("tower not found") @@ -142,6 +167,23 @@ var ( // ErrSessionFailedFilterFn indicates that a particular session did // not pass the filter func provided by the caller. ErrSessionFailedFilterFn = errors.New("session failed filter func") + + // ErrSessionNotClosable is returned when a session is not found in the + // closable list. + ErrSessionNotClosable = errors.New("session is not closable") + + // errSessionHasOpenChannels is an error used to indicate that a + // session has updates for channels that are still open. + errSessionHasOpenChannels = errors.New("session has open channels") + + // errSessionHasUnackedUpdates is an error used to indicate that a + // session has un-acked updates. + errSessionHasUnackedUpdates = errors.New("session has un-acked updates") + + // errChannelHasMoreSessions is an error used to indicate that a channel + // has updates in other non-closed sessions. + errChannelHasMoreSessions = errors.New("channel has updates in " + + "other sessions") ) // NewBoltBackendCreator returns a function that creates a new bbolt backend for @@ -241,6 +283,8 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cTowerIndexBkt, cTowerToSessionIndexBkt, cChanIDIndexBkt, + cSessionIDIndexBkt, + cClosableSessionsBkt, } for _, bucket := range buckets { @@ -723,24 +767,58 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { } } - // Add the new entry to the towerID-to-SessionID index. - indexBkt := towerToSessionIndex.NestedReadWriteBucket( - towerID.Bytes(), - ) - if indexBkt == nil { - return ErrTowerNotFound + // Get the session-ID index bucket. + dbIDIndex := tx.ReadWriteBucket(cSessionIDIndexBkt) + if dbIDIndex == nil { + return ErrUninitializedDB } - err = indexBkt.Put(session.ID[:], []byte{1}) + // Get a new, unique, ID for this session from the session-ID + // index bucket. + nextSeq, err := dbIDIndex.NextSequence() if err != nil { return err } + // Add the new entry to the dbID-to-SessionID index. + newIndex, err := writeBigSize(nextSeq) + if err != nil { + return err + } + + err = dbIDIndex.Put(newIndex, session.ID[:]) + if err != nil { + return err + } + + // Also add the db-assigned-id to the session bucket under the + // cSessionDBID key. sessionBkt, err := sessions.CreateBucket(session.ID[:]) if err != nil { return err } + err = sessionBkt.Put(cSessionDBID, newIndex) + if err != nil { + return err + } + + // TODO(elle): migrate the towerID-to-SessionID to use the + // new db-assigned sessionID's rather. + + // Add the new entry to the towerID-to-SessionID index. + towerSessions := towerToSessionIndex.NestedReadWriteBucket( + towerID.Bytes(), + ) + if towerSessions == nil { + return ErrTowerNotFound + } + + err = towerSessions.Put(session.ID[:], []byte{1}) + if err != nil { + return err + } + // Finally, write the client session's body in the sessions // bucket. return putClientSessionBody(sessionBkt, session) @@ -960,6 +1038,37 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, return byteOrder.Uint32(keyIndexBytes), nil } +// GetClientSession loads the ClientSession with the given ID from the DB. +func (c *ClientDB) GetClientSession(id SessionID, + opts ...ClientSessionListOption) (*ClientSession, error) { + + var sess *ClientSession + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + session, err := c.getClientSession( + sessionsBkt, chanIDIndexBkt, id[:], nil, opts..., + ) + if err != nil { + return err + } + + sess = session + + return nil + }, func() {}) + + return sess, err +} + // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. @@ -974,20 +1083,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID, return ErrUninitializedDB } - towers := tx.ReadBucket(cTowerBkt) - if towers == nil { - return ErrUninitializedDB - } - chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) if chanIDIndexBkt == nil { return ErrUninitializedDB } - var err error - // If no tower ID is specified, then fetch all the sessions // known to the db. + var err error if id == nil { clientSessions, err = c.listClientAllSessions( sessions, chanIDIndexBkt, filterFn, opts..., @@ -1181,7 +1284,8 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { } // FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. +// channel summaries. Only the channels that have not yet been marked as closed +// will be loaded. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { var summaries map[lnwire.ChannelID]ClientChanSummary @@ -1197,6 +1301,13 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { return ErrCorruptChanDetails } + // If this channel has already been marked as closed, + // then its summary does not need to be loaded. + closedHeight := chanDetails.Get(cChanClosedHeight) + if len(closedHeight) > 0 { + return nil + } + var chanID lnwire.ChannelID copy(chanID[:], k) @@ -1292,6 +1403,420 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, return nil } +// ListClosableSessions fetches and returns the IDs for all sessions marked as +// closable. +func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) { + sessions := make(map[SessionID]uint32) + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + csBkt := tx.ReadBucket(cClosableSessionsBkt) + if csBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + return csBkt.ForEach(func(dbIDBytes, heightBytes []byte) error { + dbID, err := readBigSize(dbIDBytes) + if err != nil { + return err + } + + sessID, err := getRealSessionID(sessIDIndexBkt, dbID) + if err != nil { + return err + } + + sessions[*sessID] = byteOrder.Uint32(heightBytes) + + return nil + }) + }, func() { + sessions = make(map[SessionID]uint32) + }) + if err != nil { + return nil, err + } + + return sessions, nil +} + +// DeleteSession can be called when a session should be deleted from the DB. +// All references to the session will also be deleted from the DB. Note that a +// session will only be deleted if was previously marked as closable. +func (c *ClientDB) DeleteSession(id SessionID) error { + return kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + closableBkt := tx.ReadWriteBucket(cClosableSessionsBkt) + if closableBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadWriteBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadWriteBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + towerToSessBkt := tx.ReadWriteBucket(cTowerToSessionIndexBkt) + if towerToSessBkt == nil { + return ErrUninitializedDB + } + + // Get the sub-bucket for this session ID. If it does not exist + // then the session has already been deleted and so our work is + // done. + sessionBkt := sessionsBkt.NestedReadBucket(id[:]) + if sessionBkt == nil { + return nil + } + + _, dbIDBytes, err := getDBSessionID(sessionsBkt, id) + if err != nil { + return err + } + + // First we check if the session has actually been marked as + // closable. + if closableBkt.Get(dbIDBytes) == nil { + return ErrSessionNotClosable + } + + sess, err := getClientSessionBody(sessionsBkt, id[:]) + if err != nil { + return err + } + + // Delete from the tower-to-sessionID index. + towerIndexBkt := towerToSessBkt.NestedReadWriteBucket( + sess.TowerID.Bytes(), + ) + if towerIndexBkt == nil { + return fmt.Errorf("no entry in the tower-to-session "+ + "index found for tower ID %v", sess.TowerID) + } + + err = towerIndexBkt.Delete(id[:]) + if err != nil { + return err + } + + // Delete entry from session ID index. + err = sessIDIndexBkt.Delete(dbIDBytes) + if err != nil { + return err + } + + // Delete the entry from the closable sessions index. + err = closableBkt.Delete(dbIDBytes) + if err != nil { + return err + } + + // Get the acked updates range index for the session. This is + // used to get the list of channels that the session has updates + // for. + ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + if ackRanges == nil { + // A session would only be considered closable if it + // was exhausted. Meaning that it should not be the + // case that it has no acked-updates. + return fmt.Errorf("cannot delete session %s since it "+ + "is not yet exhausted", id) + } + + // For each of the channels, delete the session ID entry. + err = ackRanges.ForEach(func(chanDBID, _ []byte) error { + chanDBIDInt, err := readBigSize(chanDBID) + if err != nil { + return err + } + + chanID, err := getRealChannelID( + chanIDIndexBkt, chanDBIDInt, + ) + if err != nil { + return err + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket( + chanID[:], + ) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + chanSessions := chanDetails.NestedReadWriteBucket( + cChanSessions, + ) + if chanSessions == nil { + return fmt.Errorf("no session list found for "+ + "channel %s", chanID) + } + + // Check that this session was actually listed in the + // session list for this channel. + if len(chanSessions.Get(dbIDBytes)) == 0 { + return fmt.Errorf("session %s not found in "+ + "the session list for channel %s", id, + chanID) + } + + // If it was, then delete it. + err = chanSessions.Delete(dbIDBytes) + if err != nil { + return err + } + + // If this was the last session for this channel, we can + // now delete the channel details for this channel + // completely. + err = chanSessions.ForEach(func(_, _ []byte) error { + return errChannelHasMoreSessions + }) + if errors.Is(err, errChannelHasMoreSessions) { + return nil + } else if err != nil { + return err + } + + // Delete the channel's entry from the channel-id-index. + dbID := chanDetails.Get(cChanDBID) + err = chanIDIndexBkt.Delete(dbID) + if err != nil { + return err + } + + // Delete the channel details. + return chanDetailsBkt.DeleteNestedBucket(chanID[:]) + }) + if err != nil { + return err + } + + // Delete the actual session. + return sessionsBkt.DeleteNestedBucket(id[:]) + }, func() {}) +} + +// MarkChannelClosed will mark a registered channel as closed by setting its +// closed-height as the given block height. It returns a list of session IDs for +// sessions that are now considered closable due to the close of this channel. +// The details for this channel will be deleted from the DB if there are no more +// sessions in the DB that contain updates for this channel. +func (c *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, + blockHeight uint32) ([]SessionID, error) { + + var closableSessions []SessionID + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + closableSessBkt := tx.ReadWriteBucket(cClosableSessionsBkt) + if closableSessBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:]) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + // If there are no sessions for this channel, the channel + // details can be deleted. + chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions) + if chanSessIDsBkt == nil { + return chanDetailsBkt.DeleteNestedBucket(chanID[:]) + } + + // Otherwise, mark the channel as closed. + var height [4]byte + byteOrder.PutUint32(height[:], blockHeight) + + err := chanDetails.Put(cChanClosedHeight, height[:]) + if err != nil { + return err + } + + // Now iterate through all the sessions of the channel to check + // if any of them are closeable. + return chanSessIDsBkt.ForEach(func(sessDBID, _ []byte) error { + sessDBIDInt, err := readBigSize(sessDBID) + if err != nil { + return err + } + + // Use the session-ID index to get the real session ID. + sID, err := getRealSessionID( + sessIDIndexBkt, sessDBIDInt, + ) + if err != nil { + return err + } + + isClosable, err := isSessionClosable( + sessionsBkt, chanDetailsBkt, chanIDIndexBkt, + sID, + ) + if err != nil { + return err + } + + if !isClosable { + return nil + } + + // Add session to "closableSessions" list and add the + // block height that this last channel was closed in. + // This will be used in future to determine when we + // should delete the session. + var height [4]byte + byteOrder.PutUint32(height[:], blockHeight) + err = closableSessBkt.Put(sessDBID, height[:]) + if err != nil { + return err + } + + closableSessions = append(closableSessions, *sID) + + return nil + }) + }, func() { + closableSessions = nil + }) + if err != nil { + return nil, err + } + + return closableSessions, nil +} + +// isSessionClosable returns true if a session is considered closable. A session +// is considered closable only if all the following points are true: +// 1) It has no un-acked updates. +// 2) It is exhausted (ie it can't accept any more updates) +// 3) All the channels that it has acked updates for are closed. +func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket, + id *SessionID) (bool, error) { + + sessBkt := sessionsBkt.NestedReadBucket(id[:]) + if sessBkt == nil { + return false, ErrSessionNotFound + } + + commitsBkt := sessBkt.NestedReadBucket(cSessionCommits) + if commitsBkt == nil { + // If the session has no cSessionCommits bucket then we can be + // sure that no updates have ever been committed to the session + // and so it is not yet exhausted. + return false, nil + } + + // If the session has any un-acked updates, then it is not yet closable. + err := commitsBkt.ForEach(func(_, _ []byte) error { + return errSessionHasUnackedUpdates + }) + if errors.Is(err, errSessionHasUnackedUpdates) { + return false, nil + } else if err != nil { + return false, err + } + + session, err := getClientSessionBody(sessionsBkt, id[:]) + if err != nil { + return false, err + } + + // We have already checked that the session has no more committed + // updates. So now we can check if the session is exhausted. + if session.SeqNum < session.Policy.MaxUpdates { + // If the session is not yet exhausted, it is not yet closable. + return false, nil + } + + // If the session has no acked-updates, then something is wrong since + // the above check ensures that this session has been exhausted meaning + // that it should have MaxUpdates acked updates. + ackedRangeBkt := sessBkt.NestedReadBucket(cSessionAckRangeIndex) + if ackedRangeBkt == nil { + return false, fmt.Errorf("no acked-updates found for "+ + "exhausted session %s", id) + } + + // Iterate over each of the channels that the session has acked-updates + // for. If any of those channels are not closed, then the session is + // not yet closable. + err = ackedRangeBkt.ForEach(func(dbChanID, _ []byte) error { + dbChanIDInt, err := readBigSize(dbChanID) + if err != nil { + return err + } + + chanID, err := getRealChannelID(chanIDIndexBkt, dbChanIDInt) + if err != nil { + return err + } + + // Get the channel details bucket for the channel. + chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:]) + if chanDetails == nil { + return fmt.Errorf("no channel details found for "+ + "channel %s referenced by session %s", chanID, + id) + } + + // If a closed height has been set, then the channel is closed. + closedHeight := chanDetails.Get(cChanClosedHeight) + if len(closedHeight) > 0 { + return nil + } + + // Otherwise, the channel is not yet closed meaning that the + // session is not yet closable. We break the ForEach by + // returning an error to indicate this. + return errSessionHasOpenChannels + }) + if errors.Is(err, errSessionHasOpenChannels) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} + // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. func (c *ClientDB) CommitUpdate(id *SessionID, @@ -1410,7 +1935,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return ErrUninitializedDB } - chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) if chanDetailsBkt == nil { return ErrUninitializedDB } @@ -1494,6 +2019,23 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return err } + dbSessionID, _, err := getDBSessionID(sessions, *id) + if err != nil { + return err + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket( + committedUpdate.BackupID.ChanID[:], + ) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + err = putChannelToSessionMapping(chanDetails, dbSessionID) + if err != nil { + return err + } + // Get the range index for the given session-channel pair. index, err := c.getRangeIndex(tx, *id, chanID) if err != nil { @@ -1504,6 +2046,26 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, }, func() {}) } +// putChannelToSessionMapping adds the given session ID to a channel's +// cChanSessions bucket. +func putChannelToSessionMapping(chanDetails kvdb.RwBucket, + dbSessID uint64) error { + + chanSessIDsBkt, err := chanDetails.CreateBucketIfNotExists( + cChanSessions, + ) + if err != nil { + return err + } + + b, err := writeBigSize(dbSessID) + if err != nil { + return err + } + + return chanSessIDsBkt.Put(b, []byte{1}) +} + // getClientSessionBody loads the body of a ClientSession from the sessions // bucket corresponding to the serialized session id. This does not deserialize // the CommittedUpdates, AckUpdates or the Tower associated with the session. @@ -1882,6 +2444,68 @@ func getDBChanID(chanDetailsBkt kvdb.RBucket, chanID lnwire.ChannelID) (uint64, return id, idBytes, nil } +// getDBSessionID returns the db-assigned session ID for the given real session +// ID. It returns both the uint64 and byte representation. +func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64, + []byte, error) { + + sessionBkt := sessionsBkt.NestedReadBucket(sessionID[:]) + if sessionBkt == nil { + return 0, nil, ErrClientSessionNotFound + } + + idBytes := sessionBkt.Get(cSessionDBID) + if len(idBytes) == 0 { + return 0, nil, fmt.Errorf("no db-assigned ID found for "+ + "session ID %s", sessionID) + } + + id, err := readBigSize(idBytes) + if err != nil { + return 0, nil, err + } + + return id, idBytes, nil +} + +func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID, + error) { + + dbIDBytes, err := writeBigSize(dbID) + if err != nil { + return nil, err + } + + sessIDBytes := sessIDIndexBkt.Get(dbIDBytes) + if len(sessIDBytes) != SessionIDSize { + return nil, fmt.Errorf("session ID not found") + } + + var sessID SessionID + copy(sessID[:], sessIDBytes) + + return &sessID, nil +} + +func getRealChannelID(chanIDIndexBkt kvdb.RBucket, + dbID uint64) (*lnwire.ChannelID, error) { + + dbIDBytes, err := writeBigSize(dbID) + if err != nil { + return nil, err + } + + chanIDBytes := chanIDIndexBkt.Get(dbIDBytes) + if len(chanIDBytes) != 32 { //nolint:gomnd + return nil, fmt.Errorf("channel ID not found") + } + + var chanIDS lnwire.ChannelID + copy(chanIDS[:], chanIDBytes) + + return &chanIDS, nil +} + // writeBigSize will encode the given uint64 as a BigSize byte slice. func writeBigSize(i uint64) ([]byte, error) { var b bytes.Buffer diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index cd77ec77e..b3d241175 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -3,6 +3,7 @@ package wtdb_test import ( crand "crypto/rand" "io" + "math/rand" "net" "testing" @@ -17,6 +18,8 @@ import ( "github.com/stretchr/testify/require" ) +const blobType = blob.TypeAltruistCommit + // pseudoAddr is a fake network address to be used for testing purposes. var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} @@ -193,6 +196,35 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, require.ErrorIs(h.t, err, expErr) } +func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID, + blockHeight uint32, expErr error) []wtdb.SessionID { + + h.t.Helper() + + closableSessions, err := h.db.MarkChannelClosed(id, blockHeight) + require.ErrorIs(h.t, err, expErr) + + return closableSessions +} + +func (h *clientDBHarness) listClosableSessions( + expErr error) map[wtdb.SessionID]uint32 { + + h.t.Helper() + + closableSessions, err := h.db.ListClosableSessions() + require.ErrorIs(h.t, err, expErr) + + return closableSessions +} + +func (h *clientDBHarness) deleteSession(id wtdb.SessionID, expErr error) { + h.t.Helper() + + err := h.db.DeleteSession(id) + require.ErrorIs(h.t, err, expErr) +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -605,6 +637,118 @@ func testCommitUpdate(h *clientDBHarness) { }, nil) } +// testMarkChannelClosed asserts the behaviour of MarkChannelClosed. +func testMarkChannelClosed(h *clientDBHarness) { + tower := h.newTower() + + // Create channel 1. + chanID1 := randChannelID(h.t) + + // Since we have not yet registered the channel, we expect an error + // when attempting to mark it as closed. + h.markChannelClosed(chanID1, 1, wtdb.ErrChannelNotRegistered) + + // Now register the channel. + h.registerChan(chanID1, nil, nil) + + // Since there are still no sessions that would have updates for the + // channel, marking it as closed now should succeed. + h.markChannelClosed(chanID1, 1, nil) + + // Register channel 2. + chanID2 := randChannelID(h.t) + h.registerChan(chanID2, nil, nil) + + // Create session1 with MaxUpdates set to 5. + session1 := h.randSession(h.t, tower.ID, 5) + h.insertSession(session1, nil) + + // Add an update for channel 2 in session 1 and ack it too. + update := randCommittedUpdateForChannel(h.t, chanID2, 1) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + require.Zero(h.t, lastApplied) + h.ackUpdate(&session1.ID, 1, 1, nil) + + // Marking channel 2 now should not result in any closable sessions + // since session 1 is not yet exhausted. + sl := h.markChannelClosed(chanID2, 1, nil) + require.Empty(h.t, sl) + + // Create channel 3 and 4. + chanID3 := randChannelID(h.t) + h.registerChan(chanID3, nil, nil) + + chanID4 := randChannelID(h.t) + h.registerChan(chanID4, nil, nil) + + // Add an update for channel 4 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID4, 2) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 1, lastApplied) + h.ackUpdate(&session1.ID, 2, 2, nil) + + // Add an update for channel 3 in session 1. But dont ack it yet. + update = randCommittedUpdateForChannel(h.t, chanID2, 3) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 2, lastApplied) + + // Mark channel 4 as closed & assert that session 1 is not seen as + // closable since it still has committed updates. + sl = h.markChannelClosed(chanID4, 1, nil) + require.Empty(h.t, sl) + + // Now ack the update we added above. + h.ackUpdate(&session1.ID, 3, 3, nil) + + // Mark channel 3 as closed & assert that session 1 is still not seen as + // closable since it is not yet exhausted. + sl = h.markChannelClosed(chanID3, 1, nil) + require.Empty(h.t, sl) + + // Create channel 5 and 6. + chanID5 := randChannelID(h.t) + h.registerChan(chanID5, nil, nil) + + chanID6 := randChannelID(h.t) + h.registerChan(chanID6, nil, nil) + + // Add an update for channel 5 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID5, 4) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 3, lastApplied) + h.ackUpdate(&session1.ID, 4, 4, nil) + + // Add an update for channel 6 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID6, 5) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 4, lastApplied) + h.ackUpdate(&session1.ID, 5, 5, nil) + + // The session is no exhausted. + // If we now close channel 5, session 1 should still not be closable + // since it has an update for channel 6 which is still open. + sl = h.markChannelClosed(chanID5, 1, nil) + require.Empty(h.t, sl) + require.Empty(h.t, h.listClosableSessions(nil)) + + // Also check that attempting to delete the session will fail since it + // is not yet considered closable. + h.deleteSession(session1.ID, wtdb.ErrSessionNotClosable) + + // Finally, if we close channel 6, session 1 _should_ be in the closable + // list. + sl = h.markChannelClosed(chanID6, 100, nil) + require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID}) + slMap := h.listClosableSessions(nil) + require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{ + session1.ID: 100, + }, 0) + + // Assert that we now can delete the session. + h.deleteSession(session1.ID, nil) + require.Empty(h.t, h.listClosableSessions(nil)) +} + // testAckUpdate asserts the behavior of AckUpdate. func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit @@ -821,6 +965,10 @@ func TestClientDB(t *testing.T) { name: "ack update", run: testAckUpdate, }, + { + name: "mark channel closed", + run: testMarkChannelClosed, + }, } for _, database := range dbs { @@ -841,12 +989,32 @@ func TestClientDB(t *testing.T) { // randCommittedUpdate generates a random committed update. func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { + t.Helper() + + chanID := randChannelID(t) + + return randCommittedUpdateForChannel(t, chanID, seqNum) +} + +func randChannelID(t *testing.T) lnwire.ChannelID { + t.Helper() + var chanID lnwire.ChannelID _, err := io.ReadFull(crand.Reader, chanID[:]) require.NoError(t, err) + return chanID +} + +// randCommittedUpdateForChannel generates a random committed update for the +// given channel ID. +func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID, + seqNum uint16) *wtdb.CommittedUpdate { + + t.Helper() + var hint blob.BreachHint - _, err = io.ReadFull(crand.Reader, hint[:]) + _, err := io.ReadFull(crand.Reader, hint[:]) require.NoError(t, err) encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) @@ -865,3 +1033,27 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { }, } } + +func (h *clientDBHarness) randSession(t *testing.T, + towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession { + + t.Helper() + + var id wtdb.SessionID + rand.Read(id[:]) + + return &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: towerID, + Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, + MaxUpdates: maxUpdates, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + KeyIndex: h.nextKeyIndex(towerID, blobType), + }, + ID: id, + } +} diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index f7952ba4e..639030631 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -8,6 +8,8 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration3" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" ) // log is a logger that is initialized with no output filters. This @@ -36,6 +38,8 @@ func UseLogger(logger btclog.Logger) { migration3.UseLogger(logger) migration4.UseLogger(logger) migration5.UseLogger(logger) + migration6.UseLogger(logger) + migration7.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration6/client_db.go b/watchtower/wtdb/migration6/client_db.go new file mode 100644 index 000000000..8d5ffbc29 --- /dev/null +++ b/watchtower/wtdb/migration6/client_db.go @@ -0,0 +1,114 @@ +package migration6 + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionDBID is a key used in the cSessionBkt to store the + // db-assigned-id of a session. + cSessionDBID = []byte("client-session-db-id") + + // cSessionIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> session-id + cSessionIDIndexBkt = []byte("client-session-id-index") + + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of + // the ClientSession. + cSessionBody = []byte("client-session-body") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") + + byteOrder = binary.BigEndian +) + +// MigrateSessionIDIndex adds a new session ID index to the tower client db. +// This index is a mapping from db-assigned ID (a uint64 encoded using BigSize) +// to real session ID (33 bytes). This mapping will allow us to persist session +// pointers with fewer bytes in the future. +func MigrateSessionIDIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client db to add a new session ID " + + "index which stores a mapping from db-assigned ID to real " + + "session ID") + + // Create a new top-level bucket for the index. + indexBkt, err := tx.CreateTopLevelBucket(cSessionIDIndexBkt) + if err != nil { + return err + } + + // Get the existing top-level sessions bucket. + sessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + // Iterate over the sessions bucket where each key is a session-ID. + return sessionsBkt.ForEach(func(sessionID, _ []byte) error { + // Ask the DB for a new, unique, id for the index bucket. + nextSeq, err := indexBkt.NextSequence() + if err != nil { + return err + } + + newIndex, err := writeBigSize(nextSeq) + if err != nil { + return err + } + + // Add the new db-assigned-ID to real-session-ID pair to the + // new index bucket. + err = indexBkt.Put(newIndex, sessionID) + if err != nil { + return err + } + + // Get the sub-bucket for this specific session ID. + sessionBkt := sessionsBkt.NestedReadWriteBucket(sessionID) + if sessionBkt == nil { + return ErrCorruptClientSession + } + + // Here we ensure that the session bucket includes a session + // body. The only reason we do this is so that we can simulate + // a migration fail in a test to ensure that a migration fail + // results in an untouched db. + sessionBodyBytes := sessionBkt.Get(cSessionBody) + if sessionBodyBytes == nil { + return ErrCorruptClientSession + } + + // Add the db-assigned ID of the session to the session under + // the cSessionDBID key. + return sessionBkt.Put(cSessionDBID, newIndex) + }) +} + +// writeBigSize will encode the given uint64 as a BigSize byte slice. +func writeBigSize(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} diff --git a/watchtower/wtdb/migration6/client_db_test.go b/watchtower/wtdb/migration6/client_db_test.go new file mode 100644 index 000000000..c4928e2f9 --- /dev/null +++ b/watchtower/wtdb/migration6/client_db_test.go @@ -0,0 +1,147 @@ +package migration6 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // pre is the expected data in the sessions bucket before the migration. + pre = map[string]interface{}{ + sessionIDToString(100): map[string]interface{}{ + string(cSessionBody): string([]byte{1, 2, 3}), + }, + sessionIDToString(222): map[string]interface{}{ + string(cSessionBody): string([]byte{4, 5, 6}), + }, + } + + // preFailCorruptDB should fail the migration due to no session body + // being found for a given session ID. + preFailCorruptDB = map[string]interface{}{ + sessionIDToString(100): "", + } + + // post is the expected session index after migration. + postIndex = map[string]interface{}{ + indexToString(1): sessionIDToString(100), + indexToString(2): sessionIDToString(222), + } + + // postSessions is the expected data in the sessions bucket after the + // migration. + postSessions = map[string]interface{}{ + sessionIDToString(100): map[string]interface{}{ + string(cSessionBody): string([]byte{1, 2, 3}), + string(cSessionDBID): indexToString(1), + }, + sessionIDToString(222): map[string]interface{}{ + string(cSessionBody): string([]byte{4, 5, 6}), + string(cSessionDBID): indexToString(2), + }, + } +) + +// TestMigrateSessionIDIndex tests that the MigrateSessionIDIndex function +// correctly adds a new session-id index to the DB and also correctly updates +// the existing session bucket. +func TestMigrateSessionIDIndex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + postSessions map[string]interface{} + postIndex map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + postSessions: postSessions, + postIndex: postIndex, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + pre: preFailCorruptDB, + }, + { + name: "no channel details", + shouldFail: false, + pre: nil, + postSessions: nil, + postIndex: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Before the migration we have a details bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, cSessionBkt, test.pre, + ) + } + + // After the migration, we should have an untouched + // summary bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + // If the migration fails, the details bucket + // should be untouched. + if test.shouldFail { + if err := migtest.VerifyDB( + tx, cSessionBkt, test.pre, + ); err != nil { + return err + } + + return nil + } + + // Else, we expect an updated summary bucket + // and a new index bucket. + err := migtest.VerifyDB( + tx, cSessionBkt, test.postSessions, + ) + if err != nil { + return err + } + + return migtest.VerifyDB( + tx, cSessionIDIndexBkt, test.postIndex, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateSessionIDIndex, + test.shouldFail, + ) + }) + } +} + +func indexToString(id uint64) string { + var newIndex bytes.Buffer + err := tlv.WriteVarInt(&newIndex, id, &[8]byte{}) + if err != nil { + panic(err) + } + + return newIndex.String() +} + +func sessionIDToString(id uint64) string { + var chanID SessionID + byteOrder.PutUint64(chanID[:], id) + return chanID.String() +} diff --git a/watchtower/wtdb/migration6/codec.go b/watchtower/wtdb/migration6/codec.go new file mode 100644 index 000000000..11edbf299 --- /dev/null +++ b/watchtower/wtdb/migration6/codec.go @@ -0,0 +1,17 @@ +package migration6 + +import ( + "encoding/hex" +) + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} diff --git a/watchtower/wtdb/migration6/log.go b/watchtower/wtdb/migration6/log.go new file mode 100644 index 000000000..e43e7d27e --- /dev/null +++ b/watchtower/wtdb/migration6/log.go @@ -0,0 +1,14 @@ +package migration6 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/migration7/client_db.go b/watchtower/wtdb/migration7/client_db.go new file mode 100644 index 000000000..0c3c4be40 --- /dev/null +++ b/watchtower/wtdb/migration7/client_db.go @@ -0,0 +1,202 @@ +package migration7 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAckRangeIndex => chan-id => acked-index-range + cSessionBkt = []byte("client-session-bucket") + + // cChanDetailsBkt is a top-level bucket storing: + // channel-id => cChannelSummary -> encoded ClientChanSummary. + // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 + cChanDetailsBkt = []byte("client-channel-detail-bucket") + + // cChannelSummary is a sub-bucket of cChanDetailsBkt which stores the + // encoded body of ClientChanSummary. + cChannelSummary = []byte("client-channel-summary") + + // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: + // session-id -> 1 + cChanSessions = []byte("client-channel-sessions") + + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing: + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") + + // cSessionDBID is a key used in the cSessionBkt to store the + // db-assigned-d of a session. + cSessionDBID = []byte("client-session-db-id") + + // cChanIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> channel-ID + cChanIDIndexBkt = []byte("client-channel-id-index") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") + + // byteOrder is the default endianness used when serializing integers. + byteOrder = binary.BigEndian +) + +// MigrateChannelToSessionIndex migrates the tower client DB to add an index +// from channel-to-session. This will make it easier in future to check which +// sessions have updates for which channels. +func MigrateChannelToSessionIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client DB to build a new " + + "channel-to-session index") + + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + chanIDsBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDsBkt == nil { + return ErrUninitializedDB + } + + // First gather all the new channel-to-session pairs that we want to + // add. + index, err := collectIndex(sessionsBkt) + if err != nil { + return err + } + + // Then persist those pairs to the db. + return persistIndex(chanDetailsBkt, chanIDsBkt, index) +} + +// collectIndex iterates through all the sessions and uses the keys in the +// cSessionAckRangeIndex bucket to collect all the channels that the session +// has updates for. The function returns a map from channel ID to session ID +// (using the db-assigned IDs for both). +func collectIndex(sessionsBkt kvdb.RBucket) (map[uint64]map[uint64]bool, + error) { + + index := make(map[uint64]map[uint64]bool) + err := sessionsBkt.ForEach(func(sessID, _ []byte) error { + sessionBkt := sessionsBkt.NestedReadBucket(sessID) + if sessionBkt == nil { + return ErrCorruptClientSession + } + + ackedRanges := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + if ackedRanges == nil { + return nil + } + + sessDBIDBytes := sessionBkt.Get(cSessionDBID) + if sessDBIDBytes == nil { + return ErrCorruptClientSession + } + + sessDBID, err := readUint64(sessDBIDBytes) + if err != nil { + return err + } + + return ackedRanges.ForEach(func(dbChanIDBytes, _ []byte) error { + dbChanID, err := readUint64(dbChanIDBytes) + if err != nil { + return err + } + + if _, ok := index[dbChanID]; !ok { + index[dbChanID] = make(map[uint64]bool) + } + + index[dbChanID][sessDBID] = true + + return nil + }) + }) + if err != nil { + return nil, err + } + + return index, nil +} + +// persistIndex adds the channel-to-session mapping in each channel's details +// bucket. +func persistIndex(chanDetailsBkt kvdb.RwBucket, chanIDsBkt kvdb.RBucket, + index map[uint64]map[uint64]bool) error { + + for dbChanID, sessIDs := range index { + dbChanIDBytes, err := writeUint64(dbChanID) + if err != nil { + return err + } + + realChanID := chanIDsBkt.Get(dbChanIDBytes) + + chanBkt := chanDetailsBkt.NestedReadWriteBucket(realChanID) + if chanBkt == nil { + return fmt.Errorf("channel not found") + } + + sessIDsBkt, err := chanBkt.CreateBucket(cChanSessions) + if err != nil { + return err + } + + for id := range sessIDs { + sessID, err := writeUint64(id) + if err != nil { + return err + } + + err = sessIDsBkt.Put(sessID, []byte{1}) + if err != nil { + return err + } + } + } + + return nil +} + +func writeUint64(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +func readUint64(b []byte) (uint64, error) { + r := bytes.NewReader(b) + i, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return 0, err + } + + return i, nil +} diff --git a/watchtower/wtdb/migration7/client_db_test.go b/watchtower/wtdb/migration7/client_db_test.go new file mode 100644 index 000000000..4f90edc47 --- /dev/null +++ b/watchtower/wtdb/migration7/client_db_test.go @@ -0,0 +1,191 @@ +package migration7 + +import ( + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // preDetails is the expected data of the channel details bucket before + // the migration. + preDetails = map[string]interface{}{ + channelIDString(100): map[string]interface{}{ + string(cChannelSummary): string([]byte{1, 2, 3}), + }, + channelIDString(222): map[string]interface{}{ + string(cChannelSummary): string([]byte{4, 5, 6}), + }, + } + + // preFailCorruptDB should fail the migration due to no channel summary + // being found for a given channel ID. + preFailCorruptDB = map[string]interface{}{ + channelIDString(30): map[string]interface{}{}, + } + + // channelIDIndex is the data in the channelID index that is used to + // find the mapping between the db-assigned channel ID and the real + // channel ID. + channelIDIndex = map[string]interface{}{ + uint64ToStr(10): channelIDString(100), + uint64ToStr(20): channelIDString(222), + } + + // sessions is the expected data in the sessions bucket before and + // after the migration. + sessions = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(32), + uint64ToStr(34): uint64ToStr(34), + }, + uint64ToStr(20): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(30), + }, + }, + string(cSessionDBID): uint64ToStr(66), + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(33): uint64ToStr(33), + }, + }, + string(cSessionDBID): uint64ToStr(77), + }, + } + + // postDetails is the expected data in the channel details bucket after + // the migration. + postDetails = map[string]interface{}{ + channelIDString(100): map[string]interface{}{ + string(cChannelSummary): string([]byte{1, 2, 3}), + string(cChanSessions): map[string]interface{}{ + uint64ToStr(66): string([]byte{1}), + uint64ToStr(77): string([]byte{1}), + }, + }, + channelIDString(222): map[string]interface{}{ + string(cChannelSummary): string([]byte{4, 5, 6}), + string(cChanSessions): map[string]interface{}{ + uint64ToStr(66): string([]byte{1}), + }, + }, + } +) + +// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex +// function correctly builds the new channel-to-sessionID index to the tower +// client DB. +func TestMigrateChannelToSessionIndex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + shouldFail bool + preDetails map[string]interface{} + preSessions map[string]interface{} + preChanIndex map[string]interface{} + postDetails map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + preDetails: preDetails, + preSessions: sessions, + preChanIndex: channelIDIndex, + postDetails: postDetails, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + preDetails: preFailCorruptDB, + preSessions: sessions, + }, + { + name: "no sessions", + shouldFail: false, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Before the migration we have a channel details + // bucket, a sessions bucket, a session ID index bucket + // and a channel ID index bucket. + before := func(tx kvdb.RwTx) error { + err := migtest.RestoreDB( + tx, cChanDetailsBkt, test.preDetails, + ) + if err != nil { + return err + } + + err = migtest.RestoreDB( + tx, cSessionBkt, test.preSessions, + ) + if err != nil { + return err + } + + return migtest.RestoreDB( + tx, cChanIDIndexBkt, test.preChanIndex, + ) + } + + after := func(tx kvdb.RwTx) error { + // If the migration fails, the details bucket + // should be untouched. + if test.shouldFail { + if err := migtest.VerifyDB( + tx, cChanDetailsBkt, + test.preDetails, + ); err != nil { + return err + } + + return nil + } + + // Else, we expect an updated details bucket + // and a new index bucket. + return migtest.VerifyDB( + tx, cChanDetailsBkt, test.postDetails, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateChannelToSessionIndex, + test.shouldFail, + ) + }) + } +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return sessID.String() +} + +func channelIDString(id uint64) string { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return string(chanID[:]) +} + +func uint64ToStr(id uint64) string { + b, err := writeUint64(id) + if err != nil { + panic(err) + } + + return string(b) +} diff --git a/watchtower/wtdb/migration7/codec.go b/watchtower/wtdb/migration7/codec.go new file mode 100644 index 000000000..e94cfe67d --- /dev/null +++ b/watchtower/wtdb/migration7/codec.go @@ -0,0 +1,29 @@ +package migration7 + +import "encoding/hex" + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// String returns the string representation of the ChannelID. This is just the +// hex string encoding of the ChannelID itself. +func (c ChannelID) String() string { + return hex.EncodeToString(c[:]) +} diff --git a/watchtower/wtdb/migration7/log.go b/watchtower/wtdb/migration7/log.go new file mode 100644 index 000000000..39f28b6c0 --- /dev/null +++ b/watchtower/wtdb/migration7/log.go @@ -0,0 +1,14 @@ +package migration7 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index dbcad3715..b44ed80eb 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -10,6 +10,8 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration3" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" ) // txMigration is a function which takes a prior outdated version of the @@ -59,6 +61,12 @@ var clientDBVersions = []version{ { txMigration: migration5.MigrateCompleteTowerToSessionIndex, }, + { + txMigration: migration6.MigrateSessionIDIndex, + }, + { + txMigration: migration7.MigrateChannelToSessionIndex, + }, } // getLatestDBVersion returns the last known database version. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 071ee7782..e004fcdaf 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -25,19 +25,26 @@ type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore +type channel struct { + summary *wtdb.ClientChanSummary + closedHeight uint32 + sessions map[wtdb.SessionID]bool +} + // ClientDB is a mock, in-memory database or testing the watchtower client // behavior. type ClientDB struct { nextTowerID uint64 // to be used atomically mu sync.Mutex - summaries map[lnwire.ChannelID]wtdb.ClientChanSummary + channels map[lnwire.ChannelID]*channel activeSessions map[wtdb.SessionID]wtdb.ClientSession ackedUpdates rangeIndexArrayMap persistedAckedUpdates rangeIndexKVStore committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower + closableSessions map[wtdb.SessionID]uint32 nextIndex uint32 indexes map[keyIndexKey]uint32 @@ -47,9 +54,7 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - summaries: make( - map[lnwire.ChannelID]wtdb.ClientChanSummary, - ), + channels: make(map[lnwire.ChannelID]*channel), activeSessions: make( map[wtdb.SessionID]wtdb.ClientSession, ), @@ -58,10 +63,11 @@ func NewClientDB() *ClientDB { committedUpdates: make( map[wtdb.SessionID][]wtdb.CommittedUpdate, ), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), + closableSessions: make(map[wtdb.SessionID]uint32), } } @@ -503,6 +509,13 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, continue } + // Add sessionID to channel. + channel, ok := m.channels[update.BackupID.ChanID] + if !ok { + return wtdb.ErrChannelNotRegistered + } + channel.sessions[*id] = true + // Remove the committed update from disk and mark the update as // acked. The tower last applied value is also recorded to send // along with the next update. @@ -538,22 +551,192 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, return wtdb.ErrCommittedUpdateNotFound } +// ListClosableSessions fetches and returns the IDs for all sessions marked as +// closable. +func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { + m.mu.Lock() + defer m.mu.Unlock() + + cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions)) + for id, height := range m.closableSessions { + cs[id] = height + } + + return cs, nil +} + // FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. +// channel summaries. Only the channels that have not yet been marked as closed +// will be loaded. func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { m.mu.Lock() defer m.mu.Unlock() summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) - for chanID, summary := range m.summaries { + for chanID, channel := range m.channels { + // Don't load the channel if it has been marked as closed. + if channel.closedHeight > 0 { + continue + } + summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(summary.SweepPkScript), + SweepPkScript: cloneBytes( + channel.summary.SweepPkScript, + ), } } return summaries, nil } +// MarkChannelClosed will mark a registered channel as closed by setting +// its closed-height as the given block height. It returns a list of +// session IDs for sessions that are now considered closable due to the +// close of this channel. +func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, + blockHeight uint32) ([]wtdb.SessionID, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + channel, ok := m.channels[chanID] + if !ok { + return nil, wtdb.ErrChannelNotRegistered + } + + // If there are no sessions for this channel, the channel details can be + // deleted. + if len(channel.sessions) == 0 { + delete(m.channels, chanID) + return nil, nil + } + + // Mark the channel as closed. + channel.closedHeight = blockHeight + + // Now iterate through all the sessions of the channel to check if any + // of them are closeable. + var closableSessions []wtdb.SessionID + for sessID := range channel.sessions { + isClosable, err := m.isSessionClosable(sessID) + if err != nil { + return nil, err + } + + if !isClosable { + continue + } + + closableSessions = append(closableSessions, sessID) + + // Add session to "closableSessions" list and add the block + // height that this last channel was closed in. This will be + // used in future to determine when we should delete the + // session. + m.closableSessions[sessID] = blockHeight + } + + return closableSessions, nil +} + +// isSessionClosable returns true if a session is considered closable. A session +// is considered closable only if: +// 1) It has no un-acked updates +// 2) It is exhausted (ie it cant accept any more updates) +// 3) All the channels that it has acked-updates for are closed. +func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) { + // The session is not closable if it has un-acked updates. + if len(m.committedUpdates[id]) > 0 { + return false, nil + } + + sess, ok := m.activeSessions[id] + if !ok { + return false, wtdb.ErrClientSessionNotFound + } + + // The session is not closable if it is not yet exhausted. + if sess.SeqNum != sess.Policy.MaxUpdates { + return false, nil + } + + // Iterate over each of the channels that the session has acked-updates + // for. If any of those channels are not closed, then the session is + // not yet closable. + for chanID := range m.ackedUpdates[id] { + channel, ok := m.channels[chanID] + if !ok { + continue + } + + // Channel is not yet closed, and so we can not yet delete the + // session. + if channel.closedHeight == 0 { + return false, nil + } + } + + return true, nil +} + +// GetClientSession loads the ClientSession with the given ID from the DB. +func (m *ClientDB) GetClientSession(id wtdb.SessionID, + opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { + + cfg := wtdb.NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } + + session, ok := m.activeSessions[id] + if !ok { + return nil, wtdb.ErrClientSessionNotFound + } + + if cfg.PerMaxHeight != nil { + for chanID, index := range m.ackedUpdates[session.ID] { + cfg.PerMaxHeight(&session, chanID, index.MaxHeight()) + } + } + + if cfg.PerCommittedUpdate != nil { + for _, update := range m.committedUpdates[session.ID] { + update := update + cfg.PerCommittedUpdate(&session, &update) + } + } + + return &session, nil +} + +// DeleteSession can be called when a session should be deleted from the DB. +// All references to the session will also be deleted from the DB. Note that a +// session will only be deleted if it is considered closable. +func (m *ClientDB) DeleteSession(id wtdb.SessionID) error { + m.mu.Lock() + defer m.mu.Unlock() + + _, ok := m.closableSessions[id] + if !ok { + return wtdb.ErrSessionNotClosable + } + + // For each of the channels, delete the session ID entry. + for chanID := range m.ackedUpdates[id] { + c, ok := m.channels[chanID] + if !ok { + return wtdb.ErrChannelNotRegistered + } + + delete(c.sessions, id) + } + + delete(m.closableSessions, id) + delete(m.activeSessions, id) + + return nil +} + // RegisterChannel registers a channel for use within the client database. For // now, all that is stored in the channel summary is the sweep pkscript that // we'd like any tower sweeps to pay into. In the future, this will be extended @@ -565,12 +748,15 @@ func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, m.mu.Lock() defer m.mu.Unlock() - if _, ok := m.summaries[chanID]; ok { + if _, ok := m.channels[chanID]; ok { return wtdb.ErrChannelAlreadyRegistered } - m.summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(sweepPkScript), + m.channels[chanID] = &channel{ + summary: &wtdb.ClientChanSummary{ + SweepPkScript: cloneBytes(sweepPkScript), + }, + sessions: make(map[wtdb.SessionID]bool), } return nil