diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go
index 1ce4ffb2e..92e071f13 100644
--- a/htlcswitch/interfaces.go
+++ b/htlcswitch/interfaces.go
@@ -250,7 +250,7 @@ type TowerClient interface {
 	// parameters within the client. This should be called during link
 	// startup to ensure that the client is able to support the link during
 	// operation.
-	RegisterChannel(lnwire.ChannelID) error
+	RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error
 
 	// BackupState initiates a request to back up a particular revoked
 	// state. If the method returns nil, the backup is guaranteed to be
diff --git a/htlcswitch/link.go b/htlcswitch/link.go
index 2a1e49923..7cd37ac35 100644
--- a/htlcswitch/link.go
+++ b/htlcswitch/link.go
@@ -416,7 +416,9 @@ func (l *channelLink) Start() error {
 	// If the config supplied a watchtower client, ensure the channel is
 	// registered before trying to use it during operation.
 	if l.cfg.TowerClient != nil {
-		err := l.cfg.TowerClient.RegisterChannel(l.ChanID())
+		err := l.cfg.TowerClient.RegisterChannel(
+			l.ChanID(), l.channel.State().ChanType,
+		)
 		if err != nil {
 			return err
 		}
diff --git a/lnrpc/wtclientrpc/config.go b/lnrpc/wtclientrpc/config.go
index ed0401b13..58566c19a 100644
--- a/lnrpc/wtclientrpc/config.go
+++ b/lnrpc/wtclientrpc/config.go
@@ -15,14 +15,6 @@ type Config struct {
 	// Active indicates if the watchtower client is enabled.
 	Active bool
 
-	// Client is the backing watchtower client that we'll interact with
-	// through the watchtower RPC subserver.
-	Client wtclient.Client
-
-	// AnchorClient is the backing watchtower client for anchor channels that
-	// we'll interact through the watchtower RPC subserver.
-	AnchorClient wtclient.Client
-
 	// ClientMgr is a tower client manager that manages a set of tower
 	// clients.
 	ClientMgr wtclient.TowerClientManager
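Note: the RegisterChannel signature change above threads the channel's commitment type down to the watchtower machinery, so a single manager can route each channel to the legacy or anchor client. A minimal sketch of that mapping — with stand-in ChannelType/BlobType definitions, not lnd's real channeldb.ChannelType and blob.Type bitfields — looks like this:

```go
package main

import "fmt"

// ChannelType mimics channeldb.ChannelType: a bitfield describing the
// channel's commitment format. Only the anchor bit is modeled here.
type ChannelType uint64

const AnchorOutputsBit ChannelType = 1 << 3

func (c ChannelType) HasAnchors() bool { return c&AnchorOutputsBit != 0 }

// BlobType mimics blob.Type, which selects the justice-kit format that a
// tower session commits to.
type BlobType string

const (
	TypeAltruistCommit       BlobType = "altruist-commit"
	TypeAltruistAnchorCommit BlobType = "altruist-anchor-commit"
)

// blobTypeFor shows why RegisterChannel now needs the channel type: the
// manager uses it to pick the client that speaks the right blob format.
func blobTypeFor(c ChannelType) BlobType {
	if c.HasAnchors() {
		return TypeAltruistAnchorCommit
	}
	return TypeAltruistCommit
}

func main() {
	fmt.Println(blobTypeFor(0))                // altruist-commit
	fmt.Println(blobTypeFor(AnchorOutputsBit)) // altruist-anchor-commit
}
```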
diff --git a/peer/brontide.go b/peer/brontide.go
index d7d93f4fa..1751ef5ac 100644
--- a/peer/brontide.go
+++ b/peer/brontide.go
@@ -268,12 +268,8 @@ type Config struct {
 	// HtlcNotifier is used when creating a ChannelLink.
 	HtlcNotifier *htlcswitch.HtlcNotifier
 
-	// TowerClient is used by legacy channels to backup revoked states.
-	TowerClient wtclient.Client
-
-	// AnchorTowerClient is used by anchor channels to backup revoked
-	// states.
-	AnchorTowerClient wtclient.Client
+	// TowerClient is used to back up revoked states.
+	TowerClient wtclient.TowerClientManager
 
 	// DisconnectPeer is used to disconnect this peer if the cooperative close
 	// process fails.
@@ -1040,14 +1036,8 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,
 		return p.cfg.ChainArb.NotifyContractUpdate(*chanPoint, update)
 	}
 
-	chanType := lnChan.State().ChanType
-
-	// Select the appropriate tower client based on the channel type. It's
-	// okay if the clients are disabled altogether and these values are nil,
-	// as the link will check for nilness before using either.
-	var towerClient htlcswitch.TowerClient
-	switch {
-	case chanType.IsTaproot():
+	var towerClient wtclient.TowerClientManager
+	if lnChan.ChanType().IsTaproot() {
 		// Leave the tower client as nil for now until the tower client
 		// has support for taproot channels.
 		//
@@ -1060,9 +1050,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,
 			"are not yet taproot channel compatible",
 			chanPoint)
 	}
-	case chanType.HasAnchors():
-		towerClient = p.cfg.AnchorTowerClient
-	default:
+	} else {
 		towerClient = p.cfg.TowerClient
 	}
diff --git a/rpcserver.go b/rpcserver.go
index 4685b731b..8062c4fdc 100644
--- a/rpcserver.go
+++ b/rpcserver.go
@@ -743,11 +743,10 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service,
 		r.cfg, s.cc, r.cfg.networkDir, macService, atpl, invoiceRegistry,
 		s.htlcSwitch, r.cfg.ActiveNetParams.Params, s.chanRouter,
 		routerBackend, s.nodeSigner, s.graphDB, s.chanStateDB,
-		s.sweeper, tower, s.towerClient, s.anchorTowerClient,
-		s.towerClientMgr, r.cfg.net.ResolveTCPAddr, genInvoiceFeatures,
-		genAmpInvoiceFeatures, s.getNodeAnnouncement,
-		s.updateAndBrodcastSelfNode, parseAddr, rpcsLog,
-		s.aliasMgr.GetPeerAlias,
+		s.sweeper, tower, s.towerClientMgr, r.cfg.net.ResolveTCPAddr,
+		genInvoiceFeatures, genAmpInvoiceFeatures,
+		s.getNodeAnnouncement, s.updateAndBrodcastSelfNode, parseAddr,
+		rpcsLog, s.aliasMgr.GetPeerAlias,
 	)
 	if err != nil {
 		return err
diff --git a/server.go b/server.go
index b20c692ce..076802606 100644
--- a/server.go
+++ b/server.go
@@ -284,10 +284,6 @@ type server struct {
 	towerClientMgr *wtclient.Manager
 
-	towerClient wtclient.Client
-
-	anchorTowerClient wtclient.Client
-
 	connMgr *connmgr.ConnManager
 
 	sigPool *lnwallet.SigPool
@@ -1577,7 +1573,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
 	}
 
 	// Register a legacy tower client.
-	s.towerClient, err = s.towerClientMgr.NewClient(policy)
+	_, err = s.towerClientMgr.NewClient(policy)
 	if err != nil {
 		return nil, err
 	}
@@ -1589,9 +1585,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
 		blob.Type(blob.FlagAnchorChannel)
 
 	// Register an anchors tower client.
-	s.anchorTowerClient, err = s.towerClientMgr.NewClient(
-		anchorPolicy,
-	)
+	_, err = s.towerClientMgr.NewClient(anchorPolicy)
 	if err != nil {
 		return nil, err
 	}
@@ -3782,6 +3776,17 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
 		}
 	}
 
+	// If we directly set the peer.Config's TowerClient member to
+	// s.towerClientMgr and s.towerClientMgr is nil, the TowerClient
+	// member will not evaluate to nil even though the underlying value
+	// is nil, since the interface then holds a typed nil. To avoid this
+	// gotcha, which can cause a panic, we only set the member when the
+	// manager is non-nil.
+	var towerClient wtclient.TowerClientManager
+	if s.towerClientMgr != nil {
+		towerClient = s.towerClientMgr
+	}
+
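The comment in peerConnected above describes Go's classic typed-nil interface gotcha, which is worth seeing in isolation. A runnable sketch (names are illustrative, not lnd's exact types):

```go
package main

import "fmt"

type TowerClientManager interface {
	Stop() error
}

type Manager struct{}

func (m *Manager) Stop() error { return nil }

func main() {
	var mgr *Manager // nil concrete pointer, as s.towerClientMgr may be

	// Direct assignment stores a (*Manager)(nil) inside the interface:
	// the interface now carries a type, so it no longer compares equal
	// to nil, and calling through it can panic later.
	var direct TowerClientManager = mgr
	fmt.Println(direct == nil) // false

	// Guarding the assignment leaves the interface truly nil.
	var guarded TowerClientManager
	if mgr != nil {
		guarded = mgr
	}
	fmt.Println(guarded == nil) // true
}
```

An interface value is nil only when both its dynamic type and value are unset, which is why the guarded form is required before handing the manager to peer.Config.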
 	// Now that we've established a connection, create a peer, and add it to
 	// the set of currently active peers. Configure the peer with the incoming
 	// and outgoing broadcast deltas to prevent htlcs from being accepted or
@@ -3820,8 +3825,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
 		Invoices:          s.invoices,
 		ChannelNotifier:   s.channelNotifier,
 		HtlcNotifier:      s.htlcNotifier,
-		TowerClient:       s.towerClient,
-		AnchorTowerClient: s.anchorTowerClient,
+		TowerClient:       towerClient,
 		DisconnectPeer:    s.DisconnectPeer,
 		GenNodeAnnouncement: func(...netann.NodeAnnModifier) (
 			lnwire.NodeAnnouncement, error) {
diff --git a/subrpcserver_config.go b/subrpcserver_config.go
index 27bef9f92..175515de3 100644
--- a/subrpcserver_config.go
+++ b/subrpcserver_config.go
@@ -113,8 +113,6 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config,
 	chanStateDB *channeldb.ChannelStateDB,
 	sweeper *sweep.UtxoSweeper,
 	tower *watchtower.Standalone,
-	towerClient wtclient.Client,
-	anchorTowerClient wtclient.Client,
 	towerClientMgr wtclient.TowerClientManager,
 	tcpResolver lncfg.TCPResolver,
 	genInvoiceFeatures func() *lnwire.FeatureVector,
@@ -288,15 +286,9 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config,
 		case *wtclientrpc.Config:
 			subCfgValue := extractReflectValue(subCfg)
 
-			if towerClient != nil && anchorTowerClient != nil {
+			if towerClientMgr != nil {
 				subCfgValue.FieldByName("Active").Set(
-					reflect.ValueOf(towerClient != nil),
-				)
-				subCfgValue.FieldByName("Client").Set(
-					reflect.ValueOf(towerClient),
-				)
-				subCfgValue.FieldByName("AnchorClient").Set(
-					reflect.ValueOf(anchorTowerClient),
+					reflect.ValueOf(towerClientMgr != nil),
 				)
 				subCfgValue.FieldByName("ClientMgr").Set(
 					reflect.ValueOf(towerClientMgr),
diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go
index 871960712..bacc45ab1 100644
--- a/watchtower/wtclient/client.go
+++ b/watchtower/wtclient/client.go
@@ -13,14 +13,10 @@ import (
 	"github.com/btcsuite/btcd/btcec/v2"
 	"github.com/btcsuite/btclog"
 	"github.com/lightningnetwork/lnd/build"
-	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
-	"github.com/lightningnetwork/lnd/channelnotifier"
-	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/keychain"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/lightningnetwork/lnd/lnwire"
-	"github.com/lightningnetwork/lnd/subscribe"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb"
 	"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
 	"github.com/lightningnetwork/lnd/watchtower/wtserver"
@@ -57,9 +53,7 @@ func (c *TowerClient) genSessionFilter(
 	activeOnly bool) wtdb.ClientSessionFilterFn {
 
 	return func(session *wtdb.ClientSession) bool {
-		if c.cfg.Policy.IsAnchorChannel() !=
-			session.Policy.IsAnchorChannel() {
-
+		if c.cfg.Policy.TxPolicy != session.Policy.TxPolicy {
 			return false
 		}
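genSessionFilter now requires an exact TxPolicy match instead of merely agreeing on the anchor flag. A hedged sketch — with a simplified stand-in for wtpolicy.TxPolicy, whose real definition carries more fields — shows what the stricter comparison buys: two anchor policies that differ in, say, sweep fee rate are no longer treated as interchangeable.

```go
package main

import "fmt"

// TxPolicy is a stand-in for wtpolicy.TxPolicy; the real struct carries
// the blob type, sweep fee rate, and reward parameters.
type TxPolicy struct {
	BlobType     string
	SweepFeeRate uint64
}

// sessionFilter reports whether an existing session's policy is an exact
// transaction-policy match, mirroring the stricter comparison above.
func sessionFilter(cfg TxPolicy) func(TxPolicy) bool {
	return func(session TxPolicy) bool {
		return session == cfg
	}
}

func main() {
	cfg := TxPolicy{BlobType: "anchor", SweepFeeRate: 10}
	match := sessionFilter(cfg)

	fmt.Println(match(TxPolicy{"anchor", 10})) // true: exact match
	// Both are anchor policies, so the old IsAnchorChannel comparison
	// would have accepted this session despite the fee-rate mismatch.
	fmt.Println(match(TxPolicy{"anchor", 25})) // false
}
```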
@@ -92,22 +86,6 @@ type RegisteredTower struct {
 	ActiveSessionCandidate bool
 }
 
-// Client is the primary interface used by the daemon to control a client's
-// lifecycle and backup revoked states.
-type Client interface {
-	// RegisterChannel persistently initializes any channel-dependent
-	// parameters within the client. This should be called during link
-	// startup to ensure that the client is able to support the link during
-	// operation.
-	RegisterChannel(lnwire.ChannelID) error
-
-	// BackupState initiates a request to back up a particular revoked
-	// state. If the method returns nil, the backup is guaranteed to be
-	// successful unless the justice transaction would create dust outputs
-	// when trying to abide by the negotiated policy.
-	BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
-}
-
 // BreachRetributionBuilder is a function that can be used to construct a
 // BreachRetribution from a channel ID and a commitment height.
 type BreachRetributionBuilder func(id lnwire.ChannelID,
@@ -159,6 +137,8 @@ type towerClientCfg struct {
 	// sessions recorded in the database, those sessions will be ignored and
 	// new sessions will be requested immediately.
 	Policy wtpolicy.Policy
+
+	getSweepScript func(lnwire.ChannelID) ([]byte, bool)
 }
 
 // TowerClient is a concrete implementation of the Client interface, offering a
@@ -179,11 +159,6 @@ type TowerClient struct {
 	sessionQueue *sessionQueue
 	prevTask     *wtdb.BackupID
 
-	closableSessionQueue *sessionCloseMinHeap
-
-	backupMu  sync.Mutex
-	chanInfos wtdb.ChannelInfos
-
 	statTicker *time.Ticker
 	stats      *clientStats
@@ -194,10 +169,6 @@ type TowerClient struct {
 	quit chan struct{}
 }
 
-// Compile-time constraint to ensure *TowerClient implements the Client
-// interface.
-var _ Client = (*TowerClient)(nil)
-
 // newTowerClient initializes a new TowerClient from the provided
 // towerClientCfg. An error is returned if the client could not be initialized.
 func newTowerClient(cfg *towerClientCfg) (*TowerClient, error) {
@@ -209,11 +180,6 @@ func newTowerClient(cfg *towerClientCfg) (*TowerClient, error) {
 
 	plog := build.NewPrefixLog(prefix, log)
 
-	chanInfos, err := cfg.DB.FetchChanInfos()
-	if err != nil {
-		return nil, err
-	}
-
 	queueDB := cfg.DB.GetDBQueue([]byte(identifier))
 	queue, err := NewDiskOverflowQueue[*wtdb.BackupID](
 		queueDB, cfg.MaxTasksInMemQueue, plog,
@@ -223,17 +189,15 @@ func newTowerClient(cfg *towerClientCfg) (*TowerClient, error) {
 	}
 
 	c := &TowerClient{
-		cfg:                  cfg,
-		log:                  plog,
-		pipeline:             queue,
-		activeSessions:       newSessionQueueSet(),
-		chanInfos:            chanInfos,
-		closableSessionQueue: newSessionCloseMinHeap(),
-		statTicker:           time.NewTicker(DefaultStatInterval),
-		stats:                new(clientStats),
-		newTowers:            make(chan *newTowerMsg),
-		staleTowers:          make(chan *staleTowerMsg),
-		quit:                 make(chan struct{}),
+		cfg:            cfg,
+		log:            plog,
+		pipeline:       queue,
+		activeSessions: newSessionQueueSet(),
+		statTicker:     time.NewTicker(DefaultStatInterval),
+		stats:          new(clientStats),
+		newTowers:      make(chan *newTowerMsg),
+		staleTowers:    make(chan *staleTowerMsg),
+		quit:           make(chan struct{}),
 	}
 
 	candidateTowers := newTowerListIterator()
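The new getSweepScript field on towerClientCfg inverts the old ownership: the per-channel state (chanInfos) now lives in the Manager, and each client reaches it through an injected lookup rather than holding its own copy. A small stand-alone sketch of the pattern, with illustrative names:

```go
package main

import "fmt"

// registry plays the Manager's role: it owns the channel -> sweep script
// map and hands clients a read-only lookup.
type registry struct {
	scripts map[string][]byte
}

func (r *registry) getSweepScript(chanID string) ([]byte, bool) {
	script, ok := r.scripts[chanID]
	return script, ok
}

// client plays the TowerClient's role: it sees only the callback, never
// the map itself, so all mutation stays behind the manager's lock.
type client struct {
	getSweepScript func(string) ([]byte, bool)
}

func (c *client) processTask(chanID string) {
	script, ok := c.getSweepScript(chanID)
	if !ok {
		fmt.Println("not processing task for unregistered channel:", chanID)
		return
	}
	fmt.Printf("building backup task with %d-byte sweep script\n", len(script))
}

func main() {
	r := &registry{scripts: map[string][]byte{"chan-1": make([]byte, 22)}}
	c := &client{getSweepScript: r.getSweepScript}

	c.processTask("chan-1") // found
	c.processTask("chan-2") // unregistered
}
```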
@@ -410,65 +374,9 @@ func (c *TowerClient) start() error {
 		}
 	}
 
-	chanSub, err := c.cfg.SubscribeChannelEvents()
-	if err != nil {
-		return err
-	}
-
-	// Iterate over the list of registered channels and check if
-	// any of them can be marked as closed.
-	for id := range c.chanInfos {
-		isClosed, closedHeight, err := c.isChannelClosed(id)
-		if err != nil {
-			return err
-		}
-
-		if !isClosed {
-			continue
-		}
-
-		_, err = c.cfg.DB.MarkChannelClosed(id, closedHeight)
-		if err != nil {
-			c.log.Errorf("could not mark channel(%s) as "+
-				"closed: %v", id, err)
-
-			continue
-		}
-
-		// Since the channel has been marked as closed, we can
-		// also remove it from the channel summaries map.
-		delete(c.chanInfos, id)
-	}
-
-	// Load all closable sessions.
-	closableSessions, err := c.cfg.DB.ListClosableSessions()
-	if err != nil {
-		return err
-	}
-
-	err = c.trackClosableSessions(closableSessions)
-	if err != nil {
-		return err
-	}
-
-	c.wg.Add(1)
-	go c.handleChannelCloses(chanSub)
-
-	// Subscribe to new block events.
-	blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn(
-		nil,
-	)
-	if err != nil {
-		return err
-	}
-
-	c.wg.Add(1)
-	go c.handleClosableSessions(blockEvents)
-
-	// Now start the session negotiator, which will allow us to
-	// request new session as soon as the backupDispatcher starts
-	// up.
-	err = c.negotiator.Start()
+	// Now start the session negotiator, which will allow us to request new
+	// sessions as soon as the backupDispatcher starts up.
+	err := c.negotiator.Start()
 	if err != nil {
 		return err
 	}
@@ -538,84 +446,15 @@ func (c *TowerClient) stop() error {
 	return returnErr
 }
 
-// RegisterChannel persistently initializes any channel-dependent parameters
-// within the client. This should be called during link startup to ensure that
-// the client is able to support the link during operation.
-func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
-	c.backupMu.Lock()
-	defer c.backupMu.Unlock()
-
-	// If a pkscript for this channel already exists, the channel has been
-	// previously registered.
-	if _, ok := c.chanInfos[chanID]; ok {
-		return nil
-	}
-
-	// Otherwise, generate a new sweep pkscript used to sweep funds for this
-	// channel.
-	pkScript, err := c.cfg.NewAddress()
-	if err != nil {
-		return err
-	}
-
-	// Persist the sweep pkscript so that restarts will not introduce
-	// address inflation when the channel is reregistered after a restart.
-	err = c.cfg.DB.RegisterChannel(chanID, pkScript)
-	if err != nil {
-		return err
-	}
-
-	// Finally, cache the pkscript in our in-memory cache to avoid db
-	// lookups for the remainder of the daemon's execution.
-	c.chanInfos[chanID] = &wtdb.ChannelInfo{
-		ClientChanSummary: wtdb.ClientChanSummary{
-			SweepPkScript: pkScript,
-		},
-	}
-
-	return nil
-}
-
-// BackupState initiates a request to back up a particular revoked state. If the
+// backupState initiates a request to back up a particular revoked state. If the
 // method returns nil, the backup is guaranteed to be successful unless the:
 // - justice transaction would create dust outputs when trying to abide by the
 //   negotiated policy, or
 // - breached outputs contain too little value to sweep at the target sweep
 //   fee rate.
-func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
+func (c *TowerClient) backupState(chanID *lnwire.ChannelID,
 	stateNum uint64) error {
 
-	// Make sure that this channel is registered with the tower client.
-	c.backupMu.Lock()
-	info, ok := c.chanInfos[*chanID]
-	if !ok {
-		c.backupMu.Unlock()
-
-		return ErrUnregisteredChannel
-	}
-
-	// Ignore backups that have already been presented to the client.
-	var duplicate bool
-	info.MaxHeight.WhenSome(func(maxHeight uint64) {
-		if stateNum <= maxHeight {
-			duplicate = true
-		}
-	})
-	if duplicate {
-		c.backupMu.Unlock()
-
-		c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+
-			"height=%d", chanID, stateNum)
-
-		return nil
-	}
-
-	// This backup has a higher commit height than any known backup for this
-	// channel. We'll update our tip so that we won't accept it again if the
-	// link flaps.
-	c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum)
-	c.backupMu.Unlock()
-
 	id := &wtdb.BackupID{
 		ChanID:       *chanID,
 		CommitHeight: stateNum,
@@ -667,215 +506,10 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
 	return c.getOrInitActiveQueue(candidateSession, updates), nil
 }
 
-// handleChannelCloses listens for channel close events and marks channels as
-// closed in the DB.
-//
-// NOTE: This method MUST be run as a goroutine.
-func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) {
-	defer c.wg.Done()
-
-	c.log.Debugf("Starting channel close handler")
-	defer c.log.Debugf("Stopping channel close handler")
-
-	for {
-		select {
-		case update, ok := <-chanSub.Updates():
-			if !ok {
-				c.log.Debugf("Channel notifier has exited")
-				return
-			}
-
-			// We only care about channel-close events.
-			event, ok := update.(channelnotifier.ClosedChannelEvent)
-			if !ok {
-				continue
-			}
-
-			chanID := lnwire.NewChanIDFromOutPoint(
-				&event.CloseSummary.ChanPoint,
-			)
-
-			c.log.Debugf("Received ClosedChannelEvent for "+
-				"channel: %s", chanID)
-
-			err := c.handleClosedChannel(
-				chanID, event.CloseSummary.CloseHeight,
-			)
-			if err != nil {
-				c.log.Errorf("Could not handle channel close "+
-					"event for channel(%s): %v", chanID,
-					err)
-			}
-
-		case <-c.quit:
-			return
-		}
-	}
-}
-
-// handleClosedChannel handles the closure of a single channel. It will mark the
-// channel as closed in the DB, then it will handle all the sessions that are
-// now closable due to the channel closure.
-func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
-	closeHeight uint32) error {
-
-	c.backupMu.Lock()
-	defer c.backupMu.Unlock()
-
-	// We only care about channels registered with the tower client.
-	if _, ok := c.chanInfos[chanID]; !ok {
-		return nil
-	}
-
-	c.log.Debugf("Marking channel(%s) as closed", chanID)
-
-	sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight)
-	if err != nil {
-		return fmt.Errorf("could not mark channel(%s) as closed: %w",
-			chanID, err)
-	}
-
-	closableSessions := make(map[wtdb.SessionID]uint32, len(sessions))
-	for _, sess := range sessions {
-		closableSessions[sess] = closeHeight
-	}
-
-	c.log.Debugf("Tracking %d new closable sessions as a result of "+
-		"closing channel %s", len(closableSessions), chanID)
-
-	err = c.trackClosableSessions(closableSessions)
-	if err != nil {
-		return fmt.Errorf("could not track closable sessions: %w", err)
-	}
-
-	delete(c.chanInfos, chanID)
-
-	return nil
-}
-
-// handleClosableSessions listens for new block notifications. For each block,
-// it checks the closableSessionQueue to see if there is a closable session with
-// a delete-height smaller than or equal to the new block, if there is then the
-// tower is informed that it can delete the session, and then we also delete it
-// from our DB.
-func (c *TowerClient) handleClosableSessions(
-	blocksChan *chainntnfs.BlockEpochEvent) {
-
-	defer c.wg.Done()
-
-	c.log.Debug("Starting closable sessions handler")
-	defer c.log.Debug("Stopping closable sessions handler")
-
-	for {
-		select {
-		case newBlock := <-blocksChan.Epochs:
-			if newBlock == nil {
-				return
-			}
-
-			height := uint32(newBlock.Height)
-			for {
-				select {
-				case <-c.quit:
-					return
-				default:
-				}
-
-				// If there are no closable sessions that we
-				// need to handle, then we are done and can
-				// reevaluate when the next block comes.
-				item := c.closableSessionQueue.Top()
-				if item == nil {
-					break
-				}
-
-				// If there is closable session but the delete
-				// height we have set for it is after the
-				// current block height, then our work is done.
-				if item.deleteHeight > height {
-					break
-				}
-
-				// Otherwise, we pop this item from the heap
-				// and handle it.
-				c.closableSessionQueue.Pop()
-
-				// Stop the session and remove it from the
-				// in-memory set.
-				err := c.activeSessions.StopAndRemove(
-					item.sessionID,
-				)
-				if err != nil {
-					c.log.Errorf("could not remove "+
-						"session(%s) from in-memory "+
-						"set: %v", item.sessionID, err)
-
-					return
-				}
-
-				// Fetch the session from the DB so that we can
-				// extract the Tower info.
-				sess, err := c.cfg.DB.GetClientSession(
-					item.sessionID,
-				)
-				if err != nil {
-					c.log.Errorf("error calling "+
-						"GetClientSession for "+
-						"session %s: %v",
-						item.sessionID, err)
-
-					continue
-				}
-
-				err = c.deleteSessionFromTower(sess)
-				if err != nil {
-					c.log.Errorf("error deleting "+
-						"session %s from tower: %v",
-						sess.ID, err)
-
-					continue
-				}
-
-				err = c.cfg.DB.DeleteSession(item.sessionID)
-				if err != nil {
-					c.log.Errorf("could not delete "+
-						"session(%s) from DB: %w",
-						sess.ID, err)
-
-					continue
-				}
-			}
-
-		case <-c.quit:
-			return
-		}
-	}
-}
-
-// trackClosableSessions takes in a map of session IDs to the earliest block
-// height at which the session should be deleted. For each of the sessions,
-// a random delay is added to the block height and the session is added to the
-// closableSessionQueue.
-func (c *TowerClient) trackClosableSessions(
-	sessions map[wtdb.SessionID]uint32) error {
-
-	// For each closable session, add a random delay to its close
-	// height and add it to the closableSessionQueue.
-	for sID, blockHeight := range sessions {
-		delay, err := newRandomDelay(c.cfg.SessionCloseRange)
-		if err != nil {
-			return err
-		}
-
-		deleteHeight := blockHeight + delay
-
-		c.closableSessionQueue.Push(&sessionCloseItem{
-			sessionID:    sID,
-			deleteHeight: deleteHeight,
-		})
-	}
-
-	return nil
+// stopAndRemoveSession stops the session with the given ID and removes it from
+// the in-memory active sessions set.
+func (c *TowerClient) stopAndRemoveSession(id wtdb.SessionID) error {
+	return c.activeSessions.StopAndRemove(id)
 }
 
 // deleteSessionFromTower dials the tower that we created the session with and
@@ -1154,19 +788,15 @@ func (c *TowerClient) backupDispatcher() {
 // that are rejected because the active sessionQueue is full will be cached as
 // the prevTask, and should be reprocessed after obtaining a new sessionQueue.
 func (c *TowerClient) processTask(task *wtdb.BackupID) {
-	c.backupMu.Lock()
-	summary, ok := c.chanInfos[task.ChanID]
+	script, ok := c.cfg.getSweepScript(task.ChanID)
 	if !ok {
-		c.backupMu.Unlock()
-
 		log.Infof("not processing task for unregistered channel: %s",
 			task.ChanID)
 
 		return
 	}
-	c.backupMu.Unlock()
 
-	backupTask := newBackupTask(*task, summary.SweepPkScript)
+	backupTask := newBackupTask(*task, script)
 
 	status, accepted := c.sessionQueue.AcceptTask(backupTask)
 	if accepted {
@@ -1410,22 +1040,6 @@ func (c *TowerClient) initActiveQueue(s *ClientSession,
 	return sq
 }
 
-// isChanClosed can be used to check if the channel with the given ID has been
-// closed. If it has been, the block height in which its closing transaction was
-// mined will also be returned.
-func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32,
-	error) {
-
-	chanSum, err := c.cfg.FetchClosedChannel(id)
-	if errors.Is(err, channeldb.ErrClosedChannelNotFound) {
-		return false, 0, nil
-	} else if err != nil {
-		return false, 0, err
-	}
-
-	return true, chanSum.CloseHeight, nil
-}
-
 // addTower adds a new watchtower reachable at the given address and considers
 // it for new sessions. If the watchtower already exists, then any new addresses
 // included will be considered when dialing it for session negotiations and
diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go
index f1deb697a..452404919 100644
--- a/watchtower/wtclient/client_test.go
+++ b/watchtower/wtclient/client_test.go
@@ -403,7 +403,6 @@ type testHarness struct {
 	clientDB     *wtdb.ClientDB
 	clientCfg    *wtclient.Config
 	clientPolicy wtpolicy.Policy
-	client       wtclient.Client
 	server       *serverHarness
 	net          *mockNet
@@ -564,7 +563,7 @@ func (h *testHarness) startClient() {
 	h.clientMgr, err = wtclient.NewManager(h.clientCfg)
 	require.NoError(h.t, err)
 
-	h.client, err = h.clientMgr.NewClient(h.clientPolicy)
+	_, err = h.clientMgr.NewClient(h.clientPolicy)
 	require.NoError(h.t, err)
 	require.NoError(h.t, h.clientMgr.Start())
 	require.NoError(h.t, h.clientMgr.AddTower(towerAddr))
@@ -668,7 +667,7 @@ func (h *testHarness) registerChannel(id uint64) {
 	h.t.Helper()
 
 	chanID := chanIDFromInt(id)
-	err := h.client.RegisterChannel(chanID)
+	err := h.clientMgr.RegisterChannel(chanID, channeldb.SingleFunderBit)
 	require.NoError(h.t, err)
 }
@@ -709,8 +708,8 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
 
 	chanID := chanIDFromInt(id)
 
-	err := h.client.BackupState(&chanID, retribution.RevokedStateNum)
-	require.ErrorIs(h.t, expErr, err)
+	err := h.clientMgr.BackupState(&chanID, retribution.RevokedStateNum)
+	require.ErrorIs(h.t, err, expErr)
 }
 
 // sendPayments instructs the channel identified by id to send amt to the remote
@@ -1254,6 +1253,7 @@ var clientTests = []clientTest{
 			// Restart the client and allow it to process the
 			// committed update.
 			h.startClient()
+			h.registerChannel(chanID)
 
 			// Wait for the committed update to be accepted by the
 			// tower.
@@ -1555,6 +1555,7 @@ var clientTests = []clientTest{
 			// Restart the client with a new policy.
 			h.clientPolicy.MaxUpdates = 20
 			h.startClient()
+			h.registerChannel(chanID)
 
 			// Now, queue the second half of the retributions.
 			h.backupStates(chanID, numUpdates/2, numUpdates, nil)
@@ -1605,6 +1606,7 @@ var clientTests = []clientTest{
 			// maintained across restarts.
 			require.NoError(h.t, h.clientMgr.Stop())
 			h.startClient()
+			h.registerChannel(chanID)
 
 			// Try to back up the full range of retributions. Only
 			// the second half should actually be sent.
@@ -2126,6 +2128,7 @@ var clientTests = []clientTest{
 			require.NoError(h.t, h.clientMgr.Stop())
 			h.server.start()
 			h.startClient()
+			h.registerChannel(chanID)
 
 			// Back up a few more states.
 			h.backupStates(chanID, numUpdates/2, numUpdates, nil)
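The argument swap in the backupState helper above is a real fix, not churn: require.ErrorIs(t, err, target) delegates to errors.Is(err, target), and that matching is not symmetric once errors are wrapped. A quick illustration with only the standard library:

```go
package main

import (
	"errors"
	"fmt"
)

var errSentinel = errors.New("unregistered channel")

func main() {
	got := fmt.Errorf("backup failed: %w", errSentinel)

	// Correct order: does the error we got wrap the expected sentinel?
	fmt.Println(errors.Is(got, errSentinel)) // true

	// Reversed order asks whether the sentinel wraps the got error,
	// which is false for any wrapped error, so a reversed assertion
	// would wrongly fail here (and pass vacuously when both are nil).
	fmt.Println(errors.Is(errSentinel, got)) // false
}
```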
@@ -2520,7 +2523,7 @@ var clientTests = []clientTest{
 
 			// Wait for channel to be "unregistered".
 			chanID := chanIDFromInt(chanIDInt)
 			err = wait.Predicate(func() bool {
-				err := h.client.BackupState(&chanID, 0)
+				err := h.clientMgr.BackupState(&chanID, 0)
 
 				return errors.Is(
 					err, wtclient.ErrUnregisteredChannel,
diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go
index 61386548e..be0042f43 100644
--- a/watchtower/wtclient/manager.go
+++ b/watchtower/wtclient/manager.go
@@ -1,6 +1,7 @@
 package wtclient
 
 import (
+	"errors"
 	"fmt"
 	"net"
 	"sync"
@@ -10,6 +11,8 @@ import (
 	"github.com/btcsuite/btcd/chaincfg/chainhash"
 	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
+	"github.com/lightningnetwork/lnd/channelnotifier"
+	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/input"
 	"github.com/lightningnetwork/lnd/lnwire"
 	"github.com/lightningnetwork/lnd/subscribe"
@@ -50,6 +53,18 @@ type TowerClientManager interface {
 	// LookupTower retrieves a registered watchtower through its public key.
 	LookupTower(*btcec.PublicKey, ...wtdb.ClientSessionListOption) (
 		map[blob.Type]*RegisteredTower, error)
+
+	// RegisterChannel persistently initializes any channel-dependent
+	// parameters within the client. This should be called during link
+	// startup to ensure that the client is able to support the link during
+	// operation.
+	RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error
+
+	// BackupState initiates a request to back up a particular revoked
+	// state. If the method returns nil, the backup is guaranteed to be
+	// successful unless the justice transaction would create dust outputs
+	// when trying to abide by the negotiated policy.
+	BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
 }
 
 // Config provides the TowerClient with access to the resources it requires to
@@ -143,6 +158,15 @@ type Manager struct {
 	clients   map[blob.Type]*TowerClient
 	clientsMu sync.Mutex
+
+	backupMu     sync.Mutex
+	chanInfos    wtdb.ChannelInfos
+	chanBlobType map[lnwire.ChannelID]blob.Type
+
+	closableSessionQueue *sessionCloseMinHeap
+
+	wg   sync.WaitGroup
+	quit chan struct{}
 }
 
 var _ TowerClientManager = (*Manager)(nil)
@@ -163,9 +187,18 @@ func NewManager(config *Config) (*Manager, error) {
 		cfg.WriteTimeout = DefaultWriteTimeout
 	}
 
+	chanInfos, err := cfg.DB.FetchChanInfos()
+	if err != nil {
+		return nil, err
+	}
+
 	return &Manager{
-		cfg:     &cfg,
-		clients: make(map[blob.Type]*TowerClient),
+		cfg:                  &cfg,
+		clients:              make(map[blob.Type]*TowerClient),
+		chanBlobType:         make(map[lnwire.ChannelID]blob.Type),
+		chanInfos:            chanInfos,
+		closableSessionQueue: newSessionCloseMinHeap(),
+		quit:                 make(chan struct{}),
 	}, nil
 }
@@ -182,8 +215,9 @@ func (m *Manager) NewClient(policy wtpolicy.Policy) (*TowerClient, error) {
 	}
 
 	cfg := &towerClientCfg{
-		Config: m.cfg,
-		Policy: policy,
+		Config:         m.cfg,
+		Policy:         policy,
+		getSweepScript: m.getSweepScript,
 	}
 
 	client, err := newTowerClient(cfg)
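The wg/quit fields added to Manager follow Go's conventional goroutine-lifecycle pattern used by Start/Stop below: every background goroutine is registered on the WaitGroup and watches quit, and Stop closes quit then waits. A minimal runnable sketch of the idiom:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type manager struct {
	wg   sync.WaitGroup
	quit chan struct{}
}

func (m *manager) start() {
	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		for {
			select {
			case <-time.After(10 * time.Millisecond):
				fmt.Println("tick") // stand-in for real work
			case <-m.quit:
				return // exit promptly once Stop is called
			}
		}
	}()
}

func (m *manager) stop() {
	close(m.quit) // signal all goroutines...
	m.wg.Wait()   // ...and block until they have exited
}

func main() {
	m := &manager{quit: make(chan struct{})}
	m.start()
	time.Sleep(25 * time.Millisecond)
	m.stop()
	fmt.Println("stopped cleanly")
}
```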
@@ -200,6 +234,71 @@ func (m *Manager) NewClient(policy wtpolicy.Policy) (*TowerClient, error) {
 func (m *Manager) Start() error {
 	var returnErr error
 	m.started.Do(func() {
+		chanSub, err := m.cfg.SubscribeChannelEvents()
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		// Iterate over the list of registered channels and check if any
+		// of them can be marked as closed.
+		for id := range m.chanInfos {
+			isClosed, closedHeight, err := m.isChannelClosed(id)
+			if err != nil {
+				returnErr = err
+
+				return
+			}
+
+			if !isClosed {
+				continue
+			}
+
+			_, err = m.cfg.DB.MarkChannelClosed(id, closedHeight)
+			if err != nil {
+				log.Errorf("could not mark channel(%s) as "+
+					"closed: %v", id, err)
+
+				continue
+			}
+
+			// Since the channel has been marked as closed, we can
+			// also remove it from the channel summaries map.
+			delete(m.chanInfos, id)
+		}
+
+		// Load all closable sessions.
+		closableSessions, err := m.cfg.DB.ListClosableSessions()
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		err = m.trackClosableSessions(closableSessions)
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		m.wg.Add(1)
+		go m.handleChannelCloses(chanSub)
+
+		// Subscribe to new block events.
+		blockEvents, err := m.cfg.ChainNotifier.RegisterBlockEpochNtfn(
+			nil,
+		)
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		m.wg.Add(1)
+		go m.handleClosableSessions(blockEvents)
+
 		m.clientsMu.Lock()
 		defer m.clientsMu.Unlock()
@@ -221,6 +320,9 @@ func (m *Manager) Stop() error {
 	m.clientsMu.Lock()
 	defer m.clientsMu.Unlock()
 
+	close(m.quit)
+	m.wg.Wait()
+
 	for _, client := range m.clients {
 		if err := client.stop(); err != nil {
 			returnErr = err
@@ -402,3 +504,387 @@ func (m *Manager) Policy(blobType blob.Type) (wtpolicy.Policy, error) {
 
 	return client.policy(), nil
 }
+
+// RegisterChannel persistently initializes any channel-dependent parameters
+// within the client. This should be called during link startup to ensure that
+// the client is able to support the link during operation.
+func (m *Manager) RegisterChannel(id lnwire.ChannelID,
+	chanType channeldb.ChannelType) error {
+
+	blobType := blob.TypeAltruistCommit
+	if chanType.HasAnchors() {
+		blobType = blob.TypeAltruistAnchorCommit
+	}
+
+	m.clientsMu.Lock()
+	if _, ok := m.clients[blobType]; !ok {
+		m.clientsMu.Unlock()
+
+		return fmt.Errorf("no client registered for blob type %s",
+			blobType)
+	}
+	m.clientsMu.Unlock()
+
+	m.backupMu.Lock()
+	defer m.backupMu.Unlock()
+
+	// If a pkscript for this channel already exists, the channel has been
+	// previously registered.
+	if _, ok := m.chanInfos[id]; ok {
+		// Keep track of which blob type this channel will use for
+		// updates.
+		m.chanBlobType[id] = blobType
+
+		return nil
+	}
+
+	// Otherwise, generate a new sweep pkscript used to sweep funds for this
+	// channel.
+	pkScript, err := m.cfg.NewAddress()
+	if err != nil {
+		return err
+	}
+
+	// Persist the sweep pkscript so that restarts will not introduce
+	// address inflation when the channel is reregistered after a restart.
+	err = m.cfg.DB.RegisterChannel(id, pkScript)
+	if err != nil {
+		return err
+	}
+
+	// Finally, cache the pkscript in our in-memory cache to avoid db
+	// lookups for the remainder of the daemon's execution.
+	m.chanInfos[id] = &wtdb.ChannelInfo{
+		ClientChanSummary: wtdb.ClientChanSummary{
+			SweepPkScript: pkScript,
+		},
+	}
+
+	// Keep track of which blob type this channel will use for updates.
+	m.chanBlobType[id] = blobType
+
+	return nil
+}
+
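RegisterChannel is deliberately idempotent: the sweep pkscript is generated and persisted only on first registration, so re-registering after every restart cannot inflate the wallet's address count. A stand-alone sketch of that contract, using a counting stand-in for NewAddress (names are illustrative):

```go
package main

import "fmt"

type registry struct {
	newAddress func() []byte
	scripts    map[string][]byte
}

// register generates and stores a sweep script only on first sight of the
// channel; later calls are no-ops, so restarts cause no address inflation.
func (r *registry) register(chanID string) []byte {
	if script, ok := r.scripts[chanID]; ok {
		return script
	}
	script := r.newAddress()
	r.scripts[chanID] = script
	return script
}

func main() {
	calls := 0
	r := &registry{
		newAddress: func() []byte {
			calls++
			return []byte{byte(calls)}
		},
		scripts: make(map[string][]byte),
	}

	first := r.register("chan-1")
	second := r.register("chan-1") // simulated re-registration
	fmt.Println(calls)             // 1: only one address was generated
	fmt.Println(string(first) == string(second)) // true: same script
}
```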
+// BackupState initiates a request to back up a particular revoked state. If the
+// method returns nil, the backup is guaranteed to be successful unless the
+// justice transaction would create dust outputs when trying to abide by the
+// negotiated policy.
+func (m *Manager) BackupState(chanID *lnwire.ChannelID, stateNum uint64) error {
+	select {
+	case <-m.quit:
+		return ErrClientExiting
+	default:
+	}
+
+	// Make sure that this channel is registered with the tower client.
+	m.backupMu.Lock()
+	info, ok := m.chanInfos[*chanID]
+	if !ok {
+		m.backupMu.Unlock()
+
+		return ErrUnregisteredChannel
+	}
+
+	// Ignore backups that have already been presented to the client.
+	var duplicate bool
+	info.MaxHeight.WhenSome(func(maxHeight uint64) {
+		if stateNum <= maxHeight {
+			duplicate = true
+		}
+	})
+	if duplicate {
+		m.backupMu.Unlock()
+
+		log.Debugf("Ignoring duplicate backup for chanid=%v at "+
+			"height=%d", chanID, stateNum)
+
+		return nil
+	}
+
+	// This backup has a higher commit height than any known backup for this
+	// channel. We'll update our tip so that we won't accept it again if the
+	// link flaps.
+	m.chanInfos[*chanID].MaxHeight = fn.Some(stateNum)
+
+	blobType, ok := m.chanBlobType[*chanID]
+	if !ok {
+		m.backupMu.Unlock()
+
+		return ErrUnregisteredChannel
+	}
+	m.backupMu.Unlock()
+
+	m.clientsMu.Lock()
+	client, ok := m.clients[blobType]
+	if !ok {
+		m.clientsMu.Unlock()
+
+		return fmt.Errorf("no client registered for blob type %s",
+			blobType)
+	}
+	m.clientsMu.Unlock()
+
+	return client.backupState(chanID, stateNum)
+}
+
+// isChannelClosed can be used to check if the channel with the given ID has
+// been closed. If it has been, the block height in which its closing
+// transaction was mined will also be returned.
+func (m *Manager) isChannelClosed(id lnwire.ChannelID) (bool, uint32,
+	error) {
+
+	chanSum, err := m.cfg.FetchClosedChannel(id)
+	if errors.Is(err, channeldb.ErrClosedChannelNotFound) {
+		return false, 0, nil
+	} else if err != nil {
+		return false, 0, err
+	}
+
+	return true, chanSum.CloseHeight, nil
+}
+
+// trackClosableSessions takes in a map of session IDs to the earliest block
+// height at which the session should be deleted. For each of the sessions,
+// a random delay is added to the block height and the session is added to the
+// closableSessionQueue.
+func (m *Manager) trackClosableSessions(
+	sessions map[wtdb.SessionID]uint32) error {
+
+	// For each closable session, add a random delay to its close
+	// height and add it to the closableSessionQueue.
+	for sID, blockHeight := range sessions {
+		delay, err := newRandomDelay(m.cfg.SessionCloseRange)
+		if err != nil {
+			return err
+		}
+
+		deleteHeight := blockHeight + delay
+
+		m.closableSessionQueue.Push(&sessionCloseItem{
+			sessionID:    sID,
+			deleteHeight: deleteHeight,
+		})
+	}
+
+	return nil
+}
+
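trackClosableSessions schedules each session for deletion at closeHeight plus a random delay drawn from SessionCloseRange, spreading deletions over many blocks instead of firing them all at once. Assuming newRandomDelay is a uniform draw (the helper below uses crypto/rand; the real implementation may differ):

```go
package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

// randomDelay returns a uniform value in [0, max), standing in for the
// package's newRandomDelay helper.
func randomDelay(max uint32) (uint32, error) {
	n, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
	if err != nil {
		return 0, err
	}
	return uint32(n.Int64()), nil
}

func main() {
	// Spread deletions over roughly two days' worth of blocks.
	const sessionCloseRange = 288

	closeHeight := uint32(800_000)
	delay, err := randomDelay(sessionCloseRange)
	if err != nil {
		panic(err)
	}

	deleteHeight := closeHeight + delay
	fmt.Printf("session closable at %d, scheduled for deletion at %d\n",
		closeHeight, deleteHeight)
}
```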
+// handleChannelCloses listens for channel close events and marks channels as
+// closed in the DB.
+//
+// NOTE: This method MUST be run as a goroutine.
+func (m *Manager) handleChannelCloses(chanSub subscribe.Subscription) {
+	defer m.wg.Done()
+
+	log.Debugf("Starting channel close handler")
+	defer log.Debugf("Stopping channel close handler")
+
+	for {
+		select {
+		case update, ok := <-chanSub.Updates():
+			if !ok {
+				log.Debugf("Channel notifier has exited")
+				return
+			}
+
+			// We only care about channel-close events.
+			event, ok := update.(channelnotifier.ClosedChannelEvent)
+			if !ok {
+				continue
+			}
+
+			chanID := lnwire.NewChanIDFromOutPoint(
+				&event.CloseSummary.ChanPoint,
+			)
+
+			log.Debugf("Received ClosedChannelEvent for "+
+				"channel: %s", chanID)
+
+			err := m.handleClosedChannel(
+				chanID, event.CloseSummary.CloseHeight,
+			)
+			if err != nil {
+				log.Errorf("Could not handle channel close "+
+					"event for channel(%s): %v", chanID,
+					err)
+			}
+
+		case <-m.quit:
+			return
+		}
+	}
+}
+
+// handleClosedChannel handles the closure of a single channel. It will mark the
+// channel as closed in the DB, then it will handle all the sessions that are
+// now closable due to the channel closure.
+func (m *Manager) handleClosedChannel(chanID lnwire.ChannelID,
+	closeHeight uint32) error {
+
+	m.backupMu.Lock()
+	defer m.backupMu.Unlock()
+
+	// We only care about channels registered with the tower client.
+	if _, ok := m.chanInfos[chanID]; !ok {
+		return nil
+	}
+
+	log.Debugf("Marking channel(%s) as closed", chanID)
+
+	sessions, err := m.cfg.DB.MarkChannelClosed(chanID, closeHeight)
+	if err != nil {
+		return fmt.Errorf("could not mark channel(%s) as closed: %w",
+			chanID, err)
+	}
+
+	closableSessions := make(map[wtdb.SessionID]uint32, len(sessions))
+	for _, sess := range sessions {
+		closableSessions[sess] = closeHeight
+	}
+
+	log.Debugf("Tracking %d new closable sessions as a result of "+
+		"closing channel %s", len(closableSessions), chanID)
+
+	err = m.trackClosableSessions(closableSessions)
+	if err != nil {
+		return fmt.Errorf("could not track closable sessions: %w", err)
+	}
+
+	delete(m.chanInfos, chanID)
+
+	return nil
+}
+
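handleClosableSessions, below, repeatedly peeks the head of closableSessionQueue and pops entries whose delete-height has been reached. A min-heap keyed on delete-height — here a plain container/heap sketch, assuming sessionCloseMinHeap behaves similarly — makes that loop O(log n) per deletion:

```go
package main

import (
	"container/heap"
	"fmt"
)

type closeItem struct {
	sessionID    string
	deleteHeight uint32
}

// closeHeap orders items by deleteHeight, smallest first.
type closeHeap []*closeItem

func (h closeHeap) Len() int           { return len(h) }
func (h closeHeap) Less(i, j int) bool { return h[i].deleteHeight < h[j].deleteHeight }
func (h closeHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *closeHeap) Push(x any)        { *h = append(*h, x.(*closeItem)) }
func (h *closeHeap) Pop() any {
	old := *h
	n := len(old)
	item := old[n-1]
	*h = old[:n-1]
	return item
}

func main() {
	h := &closeHeap{
		{"sess-a", 120}, {"sess-b", 100}, {"sess-c", 150},
	}
	heap.Init(h)

	// On a new block at height 125, delete every session whose
	// delete-height has been reached, smallest first.
	height := uint32(125)
	for h.Len() > 0 && (*h)[0].deleteHeight <= height {
		item := heap.Pop(h).(*closeItem)
		fmt.Println("deleting", item.sessionID, "at", item.deleteHeight)
	}
	// sess-c (150) stays queued for a later block.
}
```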
+// handleClosableSessions listens for new block notifications. For each block,
+// it checks the closableSessionQueue to see if there is a closable session
+// with a delete-height smaller than or equal to the new block; if there is,
+// then the tower is informed that it can delete the session, and we also
+// delete it from our DB.
+func (m *Manager) handleClosableSessions(
+	blocksChan *chainntnfs.BlockEpochEvent) {
+
+	defer m.wg.Done()
+
+	log.Debug("Starting closable sessions handler")
+	defer log.Debug("Stopping closable sessions handler")
+
+	for {
+		select {
+		case newBlock := <-blocksChan.Epochs:
+			if newBlock == nil {
+				return
+			}
+
+			height := uint32(newBlock.Height)
+			for {
+				select {
+				case <-m.quit:
+					return
+				default:
+				}
+
+				// If there are no closable sessions that we
+				// need to handle, then we are done and can
+				// reevaluate when the next block comes.
+				item := m.closableSessionQueue.Top()
+				if item == nil {
+					break
+				}
+
+				// If there is a closable session but the delete
+				// height we have set for it is after the current
+				// block height, then our work is done.
+				if item.deleteHeight > height {
+					break
+				}
+
+				// Otherwise, we pop this item from the heap
+				// and handle it.
+				m.closableSessionQueue.Pop()
+
+				// Fetch the session from the DB so that we can
+				// extract the Tower info.
+				sess, err := m.cfg.DB.GetClientSession(
+					item.sessionID,
+				)
+				if err != nil {
+					log.Errorf("error calling "+
+						"GetClientSession for "+
+						"session %s: %v",
+						item.sessionID, err)
+
+					continue
+				}
+
+				// Get the appropriate client.
+				m.clientsMu.Lock()
+				client, ok := m.clients[sess.Policy.BlobType]
+				if !ok {
+					m.clientsMu.Unlock()
+					log.Errorf("no client currently " +
						"active for the session type")
+
+					continue
+				}
+				m.clientsMu.Unlock()
+
+				clientName, err := client.policy().BlobType.
+					Identifier()
+				if err != nil {
+					log.Errorf("could not get client "+
+						"identifier: %v", err)
+
+					continue
+				}
+
+				// Stop the session and remove it from the
+				// in-memory set.
+				err = client.stopAndRemoveSession(
+					item.sessionID,
+				)
+				if err != nil {
+					log.Errorf("could not remove "+
+						"session(%s) from in-memory "+
+						"set of the %s client: %v",
+						item.sessionID, clientName, err)
+
+					continue
+				}
+
+				err = client.deleteSessionFromTower(sess)
+				if err != nil {
+					log.Errorf("error deleting "+
+						"session %s from tower: %v",
+						sess.ID, err)
+
+					continue
+				}
+
+				err = m.cfg.DB.DeleteSession(item.sessionID)
+				if err != nil {
+					log.Errorf("could not delete "+
+						"session(%s) from DB: %v",
+						sess.ID, err)
+
+					continue
+				}
+			}
+
+		case <-m.quit:
+			return
+		}
+	}
+}
+
+func (m *Manager) getSweepScript(id lnwire.ChannelID) ([]byte, bool) {
+	m.backupMu.Lock()
+	defer m.backupMu.Unlock()
+
+	summary, ok := m.chanInfos[id]
+	if !ok {
+		return nil, false
+	}
+
+	return summary.SweepPkScript, true
+}
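Putting the pieces together, the Manager now fronts all clients: registration records the channel's blob type, and BackupState deduplicates by commit height before dispatching to the matching client. A compressed, runnable approximation — plain maps stand in for lnd's fn.Option and wtdb types, and the duplicate check is simplified accordingly:

```go
package main

import (
	"errors"
	"fmt"
)

var errUnregistered = errors.New("unregistered channel")

type backupClient interface {
	backupState(chanID string, height uint64) error
}

type printClient struct{ name string }

func (p *printClient) backupState(chanID string, height uint64) error {
	fmt.Printf("%s: backing up %s@%d\n", p.name, chanID, height)
	return nil
}

type manager struct {
	clients   map[string]backupClient // keyed by blob type
	blobTypes map[string]string       // channel -> blob type
	maxHeight map[string]uint64       // highest height seen per channel
}

func (m *manager) backupState(chanID string, height uint64) error {
	seen, registered := m.maxHeight[chanID]
	if !registered {
		return errUnregistered
	}
	// Ignore duplicates: the link may replay states when it flaps.
	if height <= seen {
		return nil
	}
	m.maxHeight[chanID] = height
	return m.clients[m.blobTypes[chanID]].backupState(chanID, height)
}

func main() {
	m := &manager{
		clients:   map[string]backupClient{"anchor": &printClient{"anchor-client"}},
		blobTypes: map[string]string{"chan-1": "anchor"},
		maxHeight: map[string]uint64{"chan-1": 0},
	}

	fmt.Println(m.backupState("chan-1", 5)) // dispatched, prints <nil>
	fmt.Println(m.backupState("chan-1", 5)) // duplicate ignored, <nil>
	fmt.Println(m.backupState("chan-2", 1)) // unregistered channel
}
```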