From fcfdf699e3aea3688a396d2ef4fc17a843042409 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 16 May 2023 15:22:45 +0200 Subject: [PATCH] multi: move BackupState and RegisterChannel to Manager This commit moves over the last two methods, `RegisterChannel` and `BackupState` from the `Client` to the `Manager` interface. With this change, we no longer need to pass around the individual clients around and now only need to pass the manager around. To do this change, all the goroutines that handle channel closes, closable sessions needed to be moved to the Manager and so a large part of this commit is just moving this code from the TowerClient to the Manager. --- htlcswitch/interfaces.go | 2 +- htlcswitch/link.go | 4 +- lnrpc/wtclientrpc/config.go | 8 - peer/brontide.go | 22 +- rpcserver.go | 9 +- server.go | 24 +- subrpcserver_config.go | 12 +- watchtower/wtclient/client.go | 432 ++----------------------- watchtower/wtclient/client_test.go | 15 +- watchtower/wtclient/manager.go | 494 ++++++++++++++++++++++++++++- 10 files changed, 551 insertions(+), 471 deletions(-) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 1ce4ffb2e..92e071f13 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -250,7 +250,7 @@ type TowerClient interface { // parameters within the client. This should be called during link // startup to ensure that the client is able to support the link during // operation. - RegisterChannel(lnwire.ChannelID) error + RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error // BackupState initiates a request to back up a particular revoked // state. If the method returns nil, the backup is guaranteed to be diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 2a1e49923..7cd37ac35 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -416,7 +416,9 @@ func (l *channelLink) Start() error { // If the config supplied watchtower client, ensure the channel is // registered before trying to use it during operation. if l.cfg.TowerClient != nil { - err := l.cfg.TowerClient.RegisterChannel(l.ChanID()) + err := l.cfg.TowerClient.RegisterChannel( + l.ChanID(), l.channel.State().ChanType, + ) if err != nil { return err } diff --git a/lnrpc/wtclientrpc/config.go b/lnrpc/wtclientrpc/config.go index ed0401b13..58566c19a 100644 --- a/lnrpc/wtclientrpc/config.go +++ b/lnrpc/wtclientrpc/config.go @@ -15,14 +15,6 @@ type Config struct { // Active indicates if the watchtower client is enabled. Active bool - // Client is the backing watchtower client that we'll interact with - // through the watchtower RPC subserver. - Client wtclient.Client - - // AnchorClient is the backing watchtower client for anchor channels that - // we'll interact through the watchtower RPC subserver. - AnchorClient wtclient.Client - // ClientMgr is a tower client manager that manages a set of tower // clients. ClientMgr wtclient.TowerClientManager diff --git a/peer/brontide.go b/peer/brontide.go index d7d93f4fa..1751ef5ac 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -268,12 +268,8 @@ type Config struct { // HtlcNotifier is used when creating a ChannelLink. HtlcNotifier *htlcswitch.HtlcNotifier - // TowerClient is used by legacy channels to backup revoked states. - TowerClient wtclient.Client - - // AnchorTowerClient is used by anchor channels to backup revoked - // states. - AnchorTowerClient wtclient.Client + // TowerClient is used to backup revoked states. + TowerClient wtclient.TowerClientManager // DisconnectPeer is used to disconnect this peer if the cooperative close // process fails. @@ -1040,14 +1036,8 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, return p.cfg.ChainArb.NotifyContractUpdate(*chanPoint, update) } - chanType := lnChan.State().ChanType - - // Select the appropriate tower client based on the channel type. It's - // okay if the clients are disabled altogether and these values are nil, - // as the link will check for nilness before using either. - var towerClient htlcswitch.TowerClient - switch { - case chanType.IsTaproot(): + var towerClient wtclient.TowerClientManager + if lnChan.ChanType().IsTaproot() { // Leave the tower client as nil for now until the tower client // has support for taproot channels. // @@ -1060,9 +1050,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, "are not yet taproot channel compatible", chanPoint) } - case chanType.HasAnchors(): - towerClient = p.cfg.AnchorTowerClient - default: + } else { towerClient = p.cfg.TowerClient } diff --git a/rpcserver.go b/rpcserver.go index 4685b731b..8062c4fdc 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -743,11 +743,10 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, r.cfg, s.cc, r.cfg.networkDir, macService, atpl, invoiceRegistry, s.htlcSwitch, r.cfg.ActiveNetParams.Params, s.chanRouter, routerBackend, s.nodeSigner, s.graphDB, s.chanStateDB, - s.sweeper, tower, s.towerClient, s.anchorTowerClient, - s.towerClientMgr, r.cfg.net.ResolveTCPAddr, genInvoiceFeatures, - genAmpInvoiceFeatures, s.getNodeAnnouncement, - s.updateAndBrodcastSelfNode, parseAddr, rpcsLog, - s.aliasMgr.GetPeerAlias, + s.sweeper, tower, s.towerClientMgr, r.cfg.net.ResolveTCPAddr, + genInvoiceFeatures, genAmpInvoiceFeatures, + s.getNodeAnnouncement, s.updateAndBrodcastSelfNode, parseAddr, + rpcsLog, s.aliasMgr.GetPeerAlias, ) if err != nil { return err diff --git a/server.go b/server.go index b20c692ce..076802606 100644 --- a/server.go +++ b/server.go @@ -284,10 +284,6 @@ type server struct { towerClientMgr *wtclient.Manager - towerClient wtclient.Client - - anchorTowerClient wtclient.Client - connMgr *connmgr.ConnManager sigPool *lnwallet.SigPool @@ -1577,7 +1573,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } // Register a legacy tower client. - s.towerClient, err = s.towerClientMgr.NewClient(policy) + _, err = s.towerClientMgr.NewClient(policy) if err != nil { return nil, err } @@ -1589,9 +1585,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, blob.Type(blob.FlagAnchorChannel) // Register an anchors tower client. - s.anchorTowerClient, err = s.towerClientMgr.NewClient( - anchorPolicy, - ) + _, err = s.towerClientMgr.NewClient(anchorPolicy) if err != nil { return nil, err } @@ -3782,6 +3776,17 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, } } + // If we directly set the peer.Config TowerClient member to the + // s.towerClientMgr then in the case that the s.towerClientMgr is nil, + // the peer.Config's TowerClient member will not evaluate to nil even + // though the underlying value is nil. To avoid this gotcha which can + // cause a panic, we need to explicitly pass nil to the peer.Config's + // TowerClient if needed. + var towerClient wtclient.TowerClientManager + if s.towerClientMgr != nil { + towerClient = s.towerClientMgr + } + // Now that we've established a connection, create a peer, and it to the // set of currently active peers. Configure the peer with the incoming // and outgoing broadcast deltas to prevent htlcs from being accepted or @@ -3820,8 +3825,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, Invoices: s.invoices, ChannelNotifier: s.channelNotifier, HtlcNotifier: s.htlcNotifier, - TowerClient: s.towerClient, - AnchorTowerClient: s.anchorTowerClient, + TowerClient: towerClient, DisconnectPeer: s.DisconnectPeer, GenNodeAnnouncement: func(...netann.NodeAnnModifier) ( lnwire.NodeAnnouncement, error) { diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 27bef9f92..175515de3 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -113,8 +113,6 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, - towerClient wtclient.Client, - anchorTowerClient wtclient.Client, towerClientMgr wtclient.TowerClientManager, tcpResolver lncfg.TCPResolver, genInvoiceFeatures func() *lnwire.FeatureVector, @@ -288,15 +286,9 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, case *wtclientrpc.Config: subCfgValue := extractReflectValue(subCfg) - if towerClient != nil && anchorTowerClient != nil { + if towerClientMgr != nil { subCfgValue.FieldByName("Active").Set( - reflect.ValueOf(towerClient != nil), - ) - subCfgValue.FieldByName("Client").Set( - reflect.ValueOf(towerClient), - ) - subCfgValue.FieldByName("AnchorClient").Set( - reflect.ValueOf(anchorTowerClient), + reflect.ValueOf(towerClientMgr != nil), ) subCfgValue.FieldByName("ClientMgr").Set( reflect.ValueOf(towerClientMgr), diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 871960712..bacc45ab1 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -13,14 +13,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtserver" @@ -57,9 +53,7 @@ func (c *TowerClient) genSessionFilter( activeOnly bool) wtdb.ClientSessionFilterFn { return func(session *wtdb.ClientSession) bool { - if c.cfg.Policy.IsAnchorChannel() != - session.Policy.IsAnchorChannel() { - + if c.cfg.Policy.TxPolicy != session.Policy.TxPolicy { return false } @@ -92,22 +86,6 @@ type RegisteredTower struct { ActiveSessionCandidate bool } -// Client is the primary interface used by the daemon to control a client's -// lifecycle and backup revoked states. -type Client interface { - // RegisterChannel persistently initializes any channel-dependent - // parameters within the client. This should be called during link - // startup to ensure that the client is able to support the link during - // operation. - RegisterChannel(lnwire.ChannelID) error - - // BackupState initiates a request to back up a particular revoked - // state. If the method returns nil, the backup is guaranteed to be - // successful unless the justice transaction would create dust outputs - // when trying to abide by the negotiated policy. - BackupState(chanID *lnwire.ChannelID, stateNum uint64) error -} - // BreachRetributionBuilder is a function that can be used to construct a // BreachRetribution from a channel ID and a commitment height. type BreachRetributionBuilder func(id lnwire.ChannelID, @@ -159,6 +137,8 @@ type towerClientCfg struct { // sessions recorded in the database, those sessions will be ignored and // new sessions will be requested immediately. Policy wtpolicy.Policy + + getSweepScript func(lnwire.ChannelID) ([]byte, bool) } // TowerClient is a concrete implementation of the Client interface, offering a @@ -179,11 +159,6 @@ type TowerClient struct { sessionQueue *sessionQueue prevTask *wtdb.BackupID - closableSessionQueue *sessionCloseMinHeap - - backupMu sync.Mutex - chanInfos wtdb.ChannelInfos - statTicker *time.Ticker stats *clientStats @@ -194,10 +169,6 @@ type TowerClient struct { quit chan struct{} } -// Compile-time constraint to ensure *TowerClient implements the Client -// interface. -var _ Client = (*TowerClient)(nil) - // newTowerClient initializes a new TowerClient from the provided // towerClientCfg. An error is returned if the client could not be initialized. func newTowerClient(cfg *towerClientCfg) (*TowerClient, error) { @@ -209,11 +180,6 @@ func newTowerClient(cfg *towerClientCfg) (*TowerClient, error) { plog := build.NewPrefixLog(prefix, log) - chanInfos, err := cfg.DB.FetchChanInfos() - if err != nil { - return nil, err - } - queueDB := cfg.DB.GetDBQueue([]byte(identifier)) queue, err := NewDiskOverflowQueue[*wtdb.BackupID]( queueDB, cfg.MaxTasksInMemQueue, plog, @@ -223,17 +189,15 @@ func newTowerClient(cfg *towerClientCfg) (*TowerClient, error) { } c := &TowerClient{ - cfg: cfg, - log: plog, - pipeline: queue, - activeSessions: newSessionQueueSet(), - chanInfos: chanInfos, - closableSessionQueue: newSessionCloseMinHeap(), - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(clientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - quit: make(chan struct{}), + cfg: cfg, + log: plog, + pipeline: queue, + activeSessions: newSessionQueueSet(), + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(clientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + quit: make(chan struct{}), } candidateTowers := newTowerListIterator() @@ -410,65 +374,9 @@ func (c *TowerClient) start() error { } } - chanSub, err := c.cfg.SubscribeChannelEvents() - if err != nil { - return err - } - - // Iterate over the list of registered channels and check if - // any of them can be marked as closed. - for id := range c.chanInfos { - isClosed, closedHeight, err := c.isChannelClosed(id) - if err != nil { - return err - } - - if !isClosed { - continue - } - - _, err = c.cfg.DB.MarkChannelClosed(id, closedHeight) - if err != nil { - c.log.Errorf("could not mark channel(%s) as "+ - "closed: %v", id, err) - - continue - } - - // Since the channel has been marked as closed, we can - // also remove it from the channel summaries map. - delete(c.chanInfos, id) - } - - // Load all closable sessions. - closableSessions, err := c.cfg.DB.ListClosableSessions() - if err != nil { - return err - } - - err = c.trackClosableSessions(closableSessions) - if err != nil { - return err - } - - c.wg.Add(1) - go c.handleChannelCloses(chanSub) - - // Subscribe to new block events. - blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn( - nil, - ) - if err != nil { - return err - } - - c.wg.Add(1) - go c.handleClosableSessions(blockEvents) - - // Now start the session negotiator, which will allow us to - // request new session as soon as the backupDispatcher starts - // up. - err = c.negotiator.Start() + // Now start the session negotiator, which will allow us to request new + // session as soon as the backupDispatcher starts up. + err := c.negotiator.Start() if err != nil { return err } @@ -538,84 +446,15 @@ func (c *TowerClient) stop() error { return returnErr } -// RegisterChannel persistently initializes any channel-dependent parameters -// within the client. This should be called during link startup to ensure that -// the client is able to support the link during operation. -func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // If a pkscript for this channel already exists, the channel has been - // previously registered. - if _, ok := c.chanInfos[chanID]; ok { - return nil - } - - // Otherwise, generate a new sweep pkscript used to sweep funds for this - // channel. - pkScript, err := c.cfg.NewAddress() - if err != nil { - return err - } - - // Persist the sweep pkscript so that restarts will not introduce - // address inflation when the channel is reregistered after a restart. - err = c.cfg.DB.RegisterChannel(chanID, pkScript) - if err != nil { - return err - } - - // Finally, cache the pkscript in our in-memory cache to avoid db - // lookups for the remainder of the daemon's execution. - c.chanInfos[chanID] = &wtdb.ChannelInfo{ - ClientChanSummary: wtdb.ClientChanSummary{ - SweepPkScript: pkScript, - }, - } - - return nil -} - -// BackupState initiates a request to back up a particular revoked state. If the +// backupState initiates a request to back up a particular revoked state. If the // method returns nil, the backup is guaranteed to be successful unless the: // - justice transaction would create dust outputs when trying to abide by the // negotiated policy, or // - breached outputs contain too little value to sweep at the target sweep // fee rate. -func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, +func (c *TowerClient) backupState(chanID *lnwire.ChannelID, stateNum uint64) error { - // Make sure that this channel is registered with the tower client. - c.backupMu.Lock() - info, ok := c.chanInfos[*chanID] - if !ok { - c.backupMu.Unlock() - - return ErrUnregisteredChannel - } - - // Ignore backups that have already been presented to the client. - var duplicate bool - info.MaxHeight.WhenSome(func(maxHeight uint64) { - if stateNum <= maxHeight { - duplicate = true - } - }) - if duplicate { - c.backupMu.Unlock() - - c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ - "height=%d", chanID, stateNum) - - return nil - } - - // This backup has a higher commit height than any known backup for this - // channel. We'll update our tip so that we won't accept it again if the - // link flaps. - c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) - c.backupMu.Unlock() - id := &wtdb.BackupID{ ChanID: *chanID, CommitHeight: stateNum, @@ -667,215 +506,10 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { return c.getOrInitActiveQueue(candidateSession, updates), nil } -// handleChannelCloses listens for channel close events and marks channels as -// closed in the DB. -// -// NOTE: This method MUST be run as a goroutine. -func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { - defer c.wg.Done() - - c.log.Debugf("Starting channel close handler") - defer c.log.Debugf("Stopping channel close handler") - - for { - select { - case update, ok := <-chanSub.Updates(): - if !ok { - c.log.Debugf("Channel notifier has exited") - return - } - - // We only care about channel-close events. - event, ok := update.(channelnotifier.ClosedChannelEvent) - if !ok { - continue - } - - chanID := lnwire.NewChanIDFromOutPoint( - &event.CloseSummary.ChanPoint, - ) - - c.log.Debugf("Received ClosedChannelEvent for "+ - "channel: %s", chanID) - - err := c.handleClosedChannel( - chanID, event.CloseSummary.CloseHeight, - ) - if err != nil { - c.log.Errorf("Could not handle channel close "+ - "event for channel(%s): %v", chanID, - err) - } - - case <-c.quit: - return - } - } -} - -// handleClosedChannel handles the closure of a single channel. It will mark the -// channel as closed in the DB, then it will handle all the sessions that are -// now closable due to the channel closure. -func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, - closeHeight uint32) error { - - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // We only care about channels registered with the tower client. - if _, ok := c.chanInfos[chanID]; !ok { - return nil - } - - c.log.Debugf("Marking channel(%s) as closed", chanID) - - sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) - if err != nil { - return fmt.Errorf("could not mark channel(%s) as closed: %w", - chanID, err) - } - - closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) - for _, sess := range sessions { - closableSessions[sess] = closeHeight - } - - c.log.Debugf("Tracking %d new closable sessions as a result of "+ - "closing channel %s", len(closableSessions), chanID) - - err = c.trackClosableSessions(closableSessions) - if err != nil { - return fmt.Errorf("could not track closable sessions: %w", err) - } - - delete(c.chanInfos, chanID) - - return nil -} - -// handleClosableSessions listens for new block notifications. For each block, -// it checks the closableSessionQueue to see if there is a closable session with -// a delete-height smaller than or equal to the new block, if there is then the -// tower is informed that it can delete the session, and then we also delete it -// from our DB. -func (c *TowerClient) handleClosableSessions( - blocksChan *chainntnfs.BlockEpochEvent) { - - defer c.wg.Done() - - c.log.Debug("Starting closable sessions handler") - defer c.log.Debug("Stopping closable sessions handler") - - for { - select { - case newBlock := <-blocksChan.Epochs: - if newBlock == nil { - return - } - - height := uint32(newBlock.Height) - for { - select { - case <-c.quit: - return - default: - } - - // If there are no closable sessions that we - // need to handle, then we are done and can - // reevaluate when the next block comes. - item := c.closableSessionQueue.Top() - if item == nil { - break - } - - // If there is closable session but the delete - // height we have set for it is after the - // current block height, then our work is done. - if item.deleteHeight > height { - break - } - - // Otherwise, we pop this item from the heap - // and handle it. - c.closableSessionQueue.Pop() - - // Stop the session and remove it from the - // in-memory set. - err := c.activeSessions.StopAndRemove( - item.sessionID, - ) - if err != nil { - c.log.Errorf("could not remove "+ - "session(%s) from in-memory "+ - "set: %v", item.sessionID, err) - - return - } - - // Fetch the session from the DB so that we can - // extract the Tower info. - sess, err := c.cfg.DB.GetClientSession( - item.sessionID, - ) - if err != nil { - c.log.Errorf("error calling "+ - "GetClientSession for "+ - "session %s: %v", - item.sessionID, err) - - continue - } - - err = c.deleteSessionFromTower(sess) - if err != nil { - c.log.Errorf("error deleting "+ - "session %s from tower: %v", - sess.ID, err) - - continue - } - - err = c.cfg.DB.DeleteSession(item.sessionID) - if err != nil { - c.log.Errorf("could not delete "+ - "session(%s) from DB: %w", - sess.ID, err) - - continue - } - } - - case <-c.quit: - return - } - } -} - -// trackClosableSessions takes in a map of session IDs to the earliest block -// height at which the session should be deleted. For each of the sessions, -// a random delay is added to the block height and the session is added to the -// closableSessionQueue. -func (c *TowerClient) trackClosableSessions( - sessions map[wtdb.SessionID]uint32) error { - - // For each closable session, add a random delay to its close - // height and add it to the closableSessionQueue. - for sID, blockHeight := range sessions { - delay, err := newRandomDelay(c.cfg.SessionCloseRange) - if err != nil { - return err - } - - deleteHeight := blockHeight + delay - - c.closableSessionQueue.Push(&sessionCloseItem{ - sessionID: sID, - deleteHeight: deleteHeight, - }) - } - - return nil +// stopAndRemoveSession stops the session with the given ID and removes it from +// the in-memory active sessions set. +func (c *TowerClient) stopAndRemoveSession(id wtdb.SessionID) error { + return c.activeSessions.StopAndRemove(id) } // deleteSessionFromTower dials the tower that we created the session with and @@ -1154,19 +788,15 @@ func (c *TowerClient) backupDispatcher() { // that are rejected because the active sessionQueue is full will be cached as // the prevTask, and should be reprocessed after obtaining a new sessionQueue. func (c *TowerClient) processTask(task *wtdb.BackupID) { - c.backupMu.Lock() - summary, ok := c.chanInfos[task.ChanID] + script, ok := c.cfg.getSweepScript(task.ChanID) if !ok { - c.backupMu.Unlock() - log.Infof("not processing task for unregistered channel: %s", task.ChanID) return } - c.backupMu.Unlock() - backupTask := newBackupTask(*task, summary.SweepPkScript) + backupTask := newBackupTask(*task, script) status, accepted := c.sessionQueue.AcceptTask(backupTask) if accepted { @@ -1410,22 +1040,6 @@ func (c *TowerClient) initActiveQueue(s *ClientSession, return sq } -// isChanClosed can be used to check if the channel with the given ID has been -// closed. If it has been, the block height in which its closing transaction was -// mined will also be returned. -func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32, - error) { - - chanSum, err := c.cfg.FetchClosedChannel(id) - if errors.Is(err, channeldb.ErrClosedChannelNotFound) { - return false, 0, nil - } else if err != nil { - return false, 0, err - } - - return true, chanSum.CloseHeight, nil -} - // addTower adds a new watchtower reachable at the given address and considers // it for new sessions. If the watchtower already exists, then any new addresses // included will be considered when dialing it for session negotiations and diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index f1deb697a..452404919 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -403,7 +403,6 @@ type testHarness struct { clientDB *wtdb.ClientDB clientCfg *wtclient.Config clientPolicy wtpolicy.Policy - client wtclient.Client server *serverHarness net *mockNet @@ -564,7 +563,7 @@ func (h *testHarness) startClient() { h.clientMgr, err = wtclient.NewManager(h.clientCfg) require.NoError(h.t, err) - h.client, err = h.clientMgr.NewClient(h.clientPolicy) + _, err = h.clientMgr.NewClient(h.clientPolicy) require.NoError(h.t, err) require.NoError(h.t, h.clientMgr.Start()) require.NoError(h.t, h.clientMgr.AddTower(towerAddr)) @@ -668,7 +667,7 @@ func (h *testHarness) registerChannel(id uint64) { h.t.Helper() chanID := chanIDFromInt(id) - err := h.client.RegisterChannel(chanID) + err := h.clientMgr.RegisterChannel(chanID, channeldb.SingleFunderBit) require.NoError(h.t, err) } @@ -709,8 +708,8 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { chanID := chanIDFromInt(id) - err := h.client.BackupState(&chanID, retribution.RevokedStateNum) - require.ErrorIs(h.t, expErr, err) + err := h.clientMgr.BackupState(&chanID, retribution.RevokedStateNum) + require.ErrorIs(h.t, err, expErr) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -1254,6 +1253,7 @@ var clientTests = []clientTest{ // Restart the client and allow it to process the // committed update. h.startClient() + h.registerChannel(chanID) // Wait for the committed update to be accepted by the // tower. @@ -1555,6 +1555,7 @@ var clientTests = []clientTest{ // Restart the client with a new policy. h.clientPolicy.MaxUpdates = 20 h.startClient() + h.registerChannel(chanID) // Now, queue the second half of the retributions. h.backupStates(chanID, numUpdates/2, numUpdates, nil) @@ -1605,6 +1606,7 @@ var clientTests = []clientTest{ // maintained across restarts. require.NoError(h.t, h.clientMgr.Stop()) h.startClient() + h.registerChannel(chanID) // Try to back up the full range of retributions. Only // the second half should actually be sent. @@ -2126,6 +2128,7 @@ var clientTests = []clientTest{ require.NoError(h.t, h.clientMgr.Stop()) h.server.start() h.startClient() + h.registerChannel(chanID) // Back up a few more states. h.backupStates(chanID, numUpdates/2, numUpdates, nil) @@ -2520,7 +2523,7 @@ var clientTests = []clientTest{ // Wait for channel to be "unregistered". chanID := chanIDFromInt(chanIDInt) err = wait.Predicate(func() bool { - err := h.client.BackupState(&chanID, 0) + err := h.clientMgr.BackupState(&chanID, 0) return errors.Is( err, wtclient.ErrUnregisteredChannel, diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 61386548e..be0042f43 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -1,6 +1,7 @@ package wtclient import ( + "errors" "fmt" "net" "sync" @@ -10,6 +11,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/subscribe" @@ -50,6 +53,18 @@ type TowerClientManager interface { // LookupTower retrieves a registered watchtower through its public key. LookupTower(*btcec.PublicKey, ...wtdb.ClientSessionListOption) ( map[blob.Type]*RegisteredTower, error) + + // RegisterChannel persistently initializes any channel-dependent + // parameters within the client. This should be called during link + // startup to ensure that the client is able to support the link during + // operation. + RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error + + // BackupState initiates a request to back up a particular revoked + // state. If the method returns nil, the backup is guaranteed to be + // successful unless the justice transaction would create dust outputs + // when trying to abide by the negotiated policy. + BackupState(chanID *lnwire.ChannelID, stateNum uint64) error } // Config provides the TowerClient with access to the resources it requires to @@ -143,6 +158,15 @@ type Manager struct { clients map[blob.Type]*TowerClient clientsMu sync.Mutex + + backupMu sync.Mutex + chanInfos wtdb.ChannelInfos + chanBlobType map[lnwire.ChannelID]blob.Type + + closableSessionQueue *sessionCloseMinHeap + + wg sync.WaitGroup + quit chan struct{} } var _ TowerClientManager = (*Manager)(nil) @@ -163,9 +187,18 @@ func NewManager(config *Config) (*Manager, error) { cfg.WriteTimeout = DefaultWriteTimeout } + chanInfos, err := cfg.DB.FetchChanInfos() + if err != nil { + return nil, err + } + return &Manager{ - cfg: &cfg, - clients: make(map[blob.Type]*TowerClient), + cfg: &cfg, + clients: make(map[blob.Type]*TowerClient), + chanBlobType: make(map[lnwire.ChannelID]blob.Type), + chanInfos: chanInfos, + closableSessionQueue: newSessionCloseMinHeap(), + quit: make(chan struct{}), }, nil } @@ -182,8 +215,9 @@ func (m *Manager) NewClient(policy wtpolicy.Policy) (*TowerClient, error) { } cfg := &towerClientCfg{ - Config: m.cfg, - Policy: policy, + Config: m.cfg, + Policy: policy, + getSweepScript: m.getSweepScript, } client, err := newTowerClient(cfg) @@ -200,6 +234,71 @@ func (m *Manager) NewClient(policy wtpolicy.Policy) (*TowerClient, error) { func (m *Manager) Start() error { var returnErr error m.started.Do(func() { + chanSub, err := m.cfg.SubscribeChannelEvents() + if err != nil { + returnErr = err + + return + } + + // Iterate over the list of registered channels and check if any + // of them can be marked as closed. + for id := range m.chanInfos { + isClosed, closedHeight, err := m.isChannelClosed(id) + if err != nil { + returnErr = err + + return + } + + if !isClosed { + continue + } + + _, err = m.cfg.DB.MarkChannelClosed(id, closedHeight) + if err != nil { + log.Errorf("could not mark channel(%s) as "+ + "closed: %v", id, err) + + continue + } + + // Since the channel has been marked as closed, we can + // also remove it from the channel summaries map. + delete(m.chanInfos, id) + } + + // Load all closable sessions. + closableSessions, err := m.cfg.DB.ListClosableSessions() + if err != nil { + returnErr = err + + return + } + + err = m.trackClosableSessions(closableSessions) + if err != nil { + returnErr = err + + return + } + + m.wg.Add(1) + go m.handleChannelCloses(chanSub) + + // Subscribe to new block events. + blockEvents, err := m.cfg.ChainNotifier.RegisterBlockEpochNtfn( + nil, + ) + if err != nil { + returnErr = err + + return + } + + m.wg.Add(1) + go m.handleClosableSessions(blockEvents) + m.clientsMu.Lock() defer m.clientsMu.Unlock() @@ -221,6 +320,9 @@ func (m *Manager) Stop() error { m.clientsMu.Lock() defer m.clientsMu.Unlock() + close(m.quit) + m.wg.Wait() + for _, client := range m.clients { if err := client.stop(); err != nil { returnErr = err @@ -402,3 +504,387 @@ func (m *Manager) Policy(blobType blob.Type) (wtpolicy.Policy, error) { return client.policy(), nil } + +// RegisterChannel persistently initializes any channel-dependent parameters +// within the client. This should be called during link startup to ensure that +// the client is able to support the link during operation. +func (m *Manager) RegisterChannel(id lnwire.ChannelID, + chanType channeldb.ChannelType) error { + + blobType := blob.TypeAltruistCommit + if chanType.HasAnchors() { + blobType = blob.TypeAltruistAnchorCommit + } + + m.clientsMu.Lock() + if _, ok := m.clients[blobType]; !ok { + m.clientsMu.Unlock() + + return fmt.Errorf("no client registered for blob type %s", + blobType) + } + m.clientsMu.Unlock() + + m.backupMu.Lock() + defer m.backupMu.Unlock() + + // If a pkscript for this channel already exists, the channel has been + // previously registered. + if _, ok := m.chanInfos[id]; ok { + // Keep track of which blob type this channel will use for + // updates. + m.chanBlobType[id] = blobType + + return nil + } + + // Otherwise, generate a new sweep pkscript used to sweep funds for this + // channel. + pkScript, err := m.cfg.NewAddress() + if err != nil { + return err + } + + // Persist the sweep pkscript so that restarts will not introduce + // address inflation when the channel is reregistered after a restart. + err = m.cfg.DB.RegisterChannel(id, pkScript) + if err != nil { + return err + } + + // Finally, cache the pkscript in our in-memory cache to avoid db + // lookups for the remainder of the daemon's execution. + m.chanInfos[id] = &wtdb.ChannelInfo{ + ClientChanSummary: wtdb.ClientChanSummary{ + SweepPkScript: pkScript, + }, + } + + // Keep track of which blob type this channel will use for updates. + m.chanBlobType[id] = blobType + + return nil +} + +// BackupState initiates a request to back up a particular revoked state. If the +// method returns nil, the backup is guaranteed to be successful unless the +// justice transaction would create dust outputs when trying to abide by the +// negotiated policy. +func (m *Manager) BackupState(chanID *lnwire.ChannelID, stateNum uint64) error { + select { + case <-m.quit: + return ErrClientExiting + default: + } + + // Make sure that this channel is registered with the tower client. + m.backupMu.Lock() + info, ok := m.chanInfos[*chanID] + if !ok { + m.backupMu.Unlock() + + return ErrUnregisteredChannel + } + + // Ignore backups that have already been presented to the client. + var duplicate bool + info.MaxHeight.WhenSome(func(maxHeight uint64) { + if stateNum <= maxHeight { + duplicate = true + } + }) + if duplicate { + m.backupMu.Unlock() + + log.Debugf("Ignoring duplicate backup for chanid=%v at "+ + "height=%d", chanID, stateNum) + + return nil + } + + // This backup has a higher commit height than any known backup for this + // channel. We'll update our tip so that we won't accept it again if the + // link flaps. + m.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) + + blobType, ok := m.chanBlobType[*chanID] + if !ok { + m.backupMu.Unlock() + + return ErrUnregisteredChannel + } + m.backupMu.Unlock() + + m.clientsMu.Lock() + client, ok := m.clients[blobType] + if !ok { + m.clientsMu.Unlock() + + return fmt.Errorf("no client registered for blob type %s", + blobType) + } + m.clientsMu.Unlock() + + return client.backupState(chanID, stateNum) +} + +// isChanClosed can be used to check if the channel with the given ID has been +// closed. If it has been, the block height in which its closing transaction was +// mined will also be returned. +func (m *Manager) isChannelClosed(id lnwire.ChannelID) (bool, uint32, + error) { + + chanSum, err := m.cfg.FetchClosedChannel(id) + if errors.Is(err, channeldb.ErrClosedChannelNotFound) { + return false, 0, nil + } else if err != nil { + return false, 0, err + } + + return true, chanSum.CloseHeight, nil +} + +// trackClosableSessions takes in a map of session IDs to the earliest block +// height at which the session should be deleted. For each of the sessions, +// a random delay is added to the block height and the session is added to the +// closableSessionQueue. +func (m *Manager) trackClosableSessions( + sessions map[wtdb.SessionID]uint32) error { + + // For each closable session, add a random delay to its close + // height and add it to the closableSessionQueue. + for sID, blockHeight := range sessions { + delay, err := newRandomDelay(m.cfg.SessionCloseRange) + if err != nil { + return err + } + + deleteHeight := blockHeight + delay + + m.closableSessionQueue.Push(&sessionCloseItem{ + sessionID: sID, + deleteHeight: deleteHeight, + }) + } + + return nil +} + +// handleChannelCloses listens for channel close events and marks channels as +// closed in the DB. +// +// NOTE: This method MUST be run as a goroutine. +func (m *Manager) handleChannelCloses(chanSub subscribe.Subscription) { + defer m.wg.Done() + + log.Debugf("Starting channel close handler") + defer log.Debugf("Stopping channel close handler") + + for { + select { + case update, ok := <-chanSub.Updates(): + if !ok { + log.Debugf("Channel notifier has exited") + return + } + + // We only care about channel-close events. + event, ok := update.(channelnotifier.ClosedChannelEvent) + if !ok { + continue + } + + chanID := lnwire.NewChanIDFromOutPoint( + &event.CloseSummary.ChanPoint, + ) + + log.Debugf("Received ClosedChannelEvent for "+ + "channel: %s", chanID) + + err := m.handleClosedChannel( + chanID, event.CloseSummary.CloseHeight, + ) + if err != nil { + log.Errorf("Could not handle channel close "+ + "event for channel(%s): %v", chanID, + err) + } + + case <-m.quit: + return + } + } +} + +// handleClosedChannel handles the closure of a single channel. It will mark the +// channel as closed in the DB, then it will handle all the sessions that are +// now closable due to the channel closure. +func (m *Manager) handleClosedChannel(chanID lnwire.ChannelID, + closeHeight uint32) error { + + m.backupMu.Lock() + defer m.backupMu.Unlock() + + // We only care about channels registered with the tower client. + if _, ok := m.chanInfos[chanID]; !ok { + return nil + } + + log.Debugf("Marking channel(%s) as closed", chanID) + + sessions, err := m.cfg.DB.MarkChannelClosed(chanID, closeHeight) + if err != nil { + return fmt.Errorf("could not mark channel(%s) as closed: %w", + chanID, err) + } + + closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) + for _, sess := range sessions { + closableSessions[sess] = closeHeight + } + + log.Debugf("Tracking %d new closable sessions as a result of "+ + "closing channel %s", len(closableSessions), chanID) + + err = m.trackClosableSessions(closableSessions) + if err != nil { + return fmt.Errorf("could not track closable sessions: %w", err) + } + + delete(m.chanInfos, chanID) + + return nil +} + +// handleClosableSessions listens for new block notifications. For each block, +// it checks the closableSessionQueue to see if there is a closable session with +// a delete-height smaller than or equal to the new block, if there is then the +// tower is informed that it can delete the session, and then we also delete it +// from our DB. +func (m *Manager) handleClosableSessions( + blocksChan *chainntnfs.BlockEpochEvent) { + + defer m.wg.Done() + + log.Debug("Starting closable sessions handler") + defer log.Debug("Stopping closable sessions handler") + + for { + select { + case newBlock := <-blocksChan.Epochs: + if newBlock == nil { + return + } + + height := uint32(newBlock.Height) + for { + select { + case <-m.quit: + return + default: + } + + // If there are no closable sessions that we + // need to handle, then we are done and can + // reevaluate when the next block comes. + item := m.closableSessionQueue.Top() + if item == nil { + break + } + + // If there is closable session but the delete + // height we have set for it is after the + // current block height, then our work is done. + if item.deleteHeight > height { + break + } + + // Otherwise, we pop this item from the heap + // and handle it. + m.closableSessionQueue.Pop() + + // Fetch the session from the DB so that we can + // extract the Tower info. + sess, err := m.cfg.DB.GetClientSession( + item.sessionID, + ) + if err != nil { + log.Errorf("error calling "+ + "GetClientSession for "+ + "session %s: %v", + item.sessionID, err) + + continue + } + + // get appropriate client. + m.clientsMu.Lock() + client, ok := m.clients[sess.Policy.BlobType] + if !ok { + m.clientsMu.Unlock() + log.Errorf("no client currently " + + "active for the session type") + + continue + } + m.clientsMu.Unlock() + + clientName, err := client.policy().BlobType. + Identifier() + if err != nil { + log.Errorf("could not get client "+ + "identifier: %v", err) + + continue + } + + // Stop the session and remove it from the + // in-memory set. + err = client.stopAndRemoveSession( + item.sessionID, + ) + if err != nil { + log.Errorf("could not remove "+ + "session(%s) from in-memory "+ + "set of the %s client: %v", + item.sessionID, clientName, err) + + continue + } + + err = client.deleteSessionFromTower(sess) + if err != nil { + log.Errorf("error deleting "+ + "session %s from tower: %v", + sess.ID, err) + + continue + } + + err = m.cfg.DB.DeleteSession(item.sessionID) + if err != nil { + log.Errorf("could not delete "+ + "session(%s) from DB: %w", + sess.ID, err) + + continue + } + } + + case <-m.quit: + return + } + } +} + +func (m *Manager) getSweepScript(id lnwire.ChannelID) ([]byte, bool) { + m.backupMu.Lock() + defer m.backupMu.Unlock() + + summary, ok := m.chanInfos[id] + if !ok { + return nil, false + } + + return summary.SweepPkScript, true +}