diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 877f707eb..59ded9784 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -129,6 +129,10 @@ In particular, the complexity involved in the lifecycle loop has been decoupled into logical steps, with each step having its own responsibility, making it easier to reason about the payment flow. + +* [Add a watchtower tower client + multiplexer](https://github.com/lightningnetwork/lnd/pull/7702) to manage + tower clients of different types. ## Breaking Changes ## Performance Improvements diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 1ce4ffb2e..92e071f13 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -250,7 +250,7 @@ type TowerClient interface { // parameters within the client. This should be called during link // startup to ensure that the client is able to support the link during // operation. - RegisterChannel(lnwire.ChannelID) error + RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error // BackupState initiates a request to back up a particular revoked // state. If the method returns nil, the backup is guaranteed to be diff --git a/htlcswitch/link.go b/htlcswitch/link.go index e302cf3d8..629c46aeb 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -416,7 +416,9 @@ func (l *channelLink) Start() error { // If the config supplied watchtower client, ensure the channel is // registered before trying to use it during operation. if l.cfg.TowerClient != nil { - err := l.cfg.TowerClient.RegisterChannel(l.ChanID()) + err := l.cfg.TowerClient.RegisterChannel( + l.ChanID(), l.channel.State().ChanType, + ) if err != nil { return err } diff --git a/lnrpc/wtclientrpc/config.go b/lnrpc/wtclientrpc/config.go index 9127c0846..b95fb1197 100644 --- a/lnrpc/wtclientrpc/config.go +++ b/lnrpc/wtclientrpc/config.go @@ -15,13 +15,9 @@ type Config struct { // Active indicates if the watchtower client is enabled. Active bool - // Client is the backing watchtower client that we'll interact with - // through the watchtower RPC subserver. - Client wtclient.Client - - // AnchorClient is the backing watchtower client for anchor channels that - // we'll interact through the watchtower RPC subserver. - AnchorClient wtclient.Client + // ClientMgr is a tower client manager that manages a set of tower + // clients. 
+ ClientMgr wtclient.ClientManager // Resolver is a custom resolver that will be used to resolve watchtower // addresses to ensure we don't leak any information when running over diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index aa7e1ae21..bf7f9a1da 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -16,9 +16,9 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower" + "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "google.golang.org/grpc" "gopkg.in/macaroon-bakery.v2/bakery" ) @@ -208,11 +208,7 @@ func (c *WatchtowerClient) AddTower(ctx context.Context, Address: addr, } - // TODO(conner): make atomic via multiplexed client - if err := c.cfg.Client.AddTower(towerAddr); err != nil { - return nil, err - } - if err := c.cfg.AnchorClient.AddTower(towerAddr); err != nil { + if err := c.cfg.ClientMgr.AddTower(towerAddr); err != nil { return nil, err } @@ -247,12 +243,7 @@ func (c *WatchtowerClient) RemoveTower(ctx context.Context, } } - // TODO(conner): make atomic via multiplexed client - err = c.cfg.Client.RemoveTower(pubKey, addr) - if err != nil { - return nil, err - } - err = c.cfg.AnchorClient.RemoveTower(pubKey, addr) + err = c.cfg.ClientMgr.RemoveTower(pubKey, addr) if err != nil { return nil, err } @@ -272,23 +263,7 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, req.IncludeSessions, req.ExcludeExhaustedSessions, ) - anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers(opts...) - if err != nil { - return nil, err - } - - // Collect all the anchor client towers. - rpcTowers := make(map[wtdb.TowerID]*Tower) - for _, tower := range anchorTowers { - rpcTower := marshallTower( - tower, PolicyType_ANCHOR, req.IncludeSessions, - ackCounts, committedUpdateCounts, - ) - - rpcTowers[tower.ID] = rpcTower - } - - legacyTowers, err := c.cfg.Client.RegisteredTowers(opts...) + towersPerBlobType, err := c.cfg.ClientMgr.RegisteredTowers(opts...) if err != nil { return nil, err } @@ -296,20 +271,32 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, // Collect all the legacy client towers. If it has any of the same // towers that the anchors client has, then just add the session info // for the legacy client to the existing tower. - for _, tower := range legacyTowers { - rpcTower := marshallTower( - tower, PolicyType_LEGACY, req.IncludeSessions, - ackCounts, committedUpdateCounts, - ) - - t, ok := rpcTowers[tower.ID] - if !ok { - rpcTowers[tower.ID] = rpcTower - continue + rpcTowers := make(map[wtdb.TowerID]*Tower) + for blobType, towers := range towersPerBlobType { + policyType := PolicyType_LEGACY + if blobType.IsAnchorChannel() { + policyType = PolicyType_ANCHOR } - t.SessionInfo = append(t.SessionInfo, rpcTower.SessionInfo...) - t.Sessions = append(t.Sessions, rpcTower.Sessions...) 
+ for _, tower := range towers { + rpcTower := marshallTower( + tower, policyType, req.IncludeSessions, + ackCounts, committedUpdateCounts, + ) + + t, ok := rpcTowers[tower.ID] + if !ok { + rpcTowers[tower.ID] = rpcTower + continue + } + + t.SessionInfo = append( + t.SessionInfo, rpcTower.SessionInfo..., + ) + t.Sessions = append( + t.Sessions, rpcTower.Sessions..., + ) + } } towers := make([]*Tower, 0, len(rpcTowers)) @@ -337,40 +324,42 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context, req.IncludeSessions, req.ExcludeExhaustedSessions, ) - // Get the tower and its sessions from anchors client. - tower, err := c.cfg.AnchorClient.LookupTower(pubKey, opts...) - if err != nil { - return nil, err - } - rpcTower := marshallTower( - tower, PolicyType_ANCHOR, req.IncludeSessions, ackCounts, - committedUpdateCounts, - ) - - // Get the tower and its sessions from legacy client. - tower, err = c.cfg.Client.LookupTower(pubKey, opts...) + towersPerBlobType, err := c.cfg.ClientMgr.LookupTower(pubKey, opts...) if err != nil { return nil, err } - rpcLegacyTower := marshallTower( - tower, PolicyType_LEGACY, req.IncludeSessions, ackCounts, - committedUpdateCounts, - ) + var resTower *Tower + for blobType, tower := range towersPerBlobType { + policyType := PolicyType_LEGACY + if blobType.IsAnchorChannel() { + policyType = PolicyType_ANCHOR + } - if !bytes.Equal(rpcTower.Pubkey, rpcLegacyTower.Pubkey) { - return nil, fmt.Errorf("legacy and anchor clients returned " + - "inconsistent results for the given tower") + rpcTower := marshallTower( + tower, policyType, req.IncludeSessions, + ackCounts, committedUpdateCounts, + ) + + if resTower == nil { + resTower = rpcTower + continue + } + + if !bytes.Equal(rpcTower.Pubkey, resTower.Pubkey) { + return nil, fmt.Errorf("tower clients returned " + + "inconsistent results for the given tower") + } + + resTower.SessionInfo = append( + resTower.SessionInfo, rpcTower.SessionInfo..., + ) + resTower.Sessions = append( + resTower.Sessions, rpcTower.Sessions..., + ) } - rpcTower.SessionInfo = append( - rpcTower.SessionInfo, rpcLegacyTower.SessionInfo..., - ) - rpcTower.Sessions = append( - rpcTower.Sessions, rpcLegacyTower.Sessions..., - ) - - return rpcTower, nil + return resTower, nil } // constructFunctionalOptions is a helper function that constructs a list of @@ -422,30 +411,14 @@ func constructFunctionalOptions(includeSessions, } // Stats returns the in-memory statistics of the client since startup. -func (c *WatchtowerClient) Stats(ctx context.Context, - req *StatsRequest) (*StatsResponse, error) { +func (c *WatchtowerClient) Stats(_ context.Context, + _ *StatsRequest) (*StatsResponse, error) { if err := c.isActive(); err != nil { return nil, err } - clientStats := []wtclient.ClientStats{ - c.cfg.Client.Stats(), - c.cfg.AnchorClient.Stats(), - } - - var stats wtclient.ClientStats - for i := range clientStats { - // Grab a reference to the slice index rather than copying bc - // ClientStats contains a lock which cannot be copied by value. 
- stat := &clientStats[i] - - stats.NumTasksAccepted += stat.NumTasksAccepted - stats.NumTasksIneligible += stat.NumTasksIneligible - stats.NumTasksPending += stat.NumTasksPending - stats.NumSessionsAcquired += stat.NumSessionsAcquired - stats.NumSessionsExhausted += stat.NumSessionsExhausted - } + stats := c.cfg.ClientMgr.Stats() return &StatsResponse{ NumBackups: uint32(stats.NumTasksAccepted), @@ -464,17 +437,22 @@ func (c *WatchtowerClient) Policy(ctx context.Context, return nil, err } - var policy wtpolicy.Policy + var blobType blob.Type switch req.PolicyType { case PolicyType_LEGACY: - policy = c.cfg.Client.Policy() + blobType = blob.TypeAltruistCommit case PolicyType_ANCHOR: - policy = c.cfg.AnchorClient.Policy() + blobType = blob.TypeAltruistAnchorCommit default: return nil, fmt.Errorf("unknown policy type: %v", req.PolicyType) } + policy, err := c.cfg.ClientMgr.Policy(blobType) + if err != nil { + return nil, err + } + return &PolicyResponse{ MaxUpdates: uint32(policy.MaxUpdates), SweepSatPerVbyte: uint32(policy.SweepFeeRate.FeePerVByte()), diff --git a/peer/brontide.go b/peer/brontide.go index 35cdbd847..244899792 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -268,12 +268,8 @@ type Config struct { // HtlcNotifier is used when creating a ChannelLink. HtlcNotifier *htlcswitch.HtlcNotifier - // TowerClient is used by legacy channels to backup revoked states. - TowerClient wtclient.Client - - // AnchorTowerClient is used by anchor channels to backup revoked - // states. - AnchorTowerClient wtclient.Client + // TowerClient is used to backup revoked states. + TowerClient wtclient.ClientManager // DisconnectPeer is used to disconnect this peer if the cooperative close // process fails. @@ -1040,14 +1036,8 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, return p.cfg.ChainArb.NotifyContractUpdate(*chanPoint, update) } - chanType := lnChan.State().ChanType - - // Select the appropriate tower client based on the channel type. It's - // okay if the clients are disabled altogether and these values are nil, - // as the link will check for nilness before using either. - var towerClient htlcswitch.TowerClient - switch { - case chanType.IsTaproot(): + var towerClient wtclient.ClientManager + if lnChan.ChanType().IsTaproot() { // Leave the tower client as nil for now until the tower client // has support for taproot channels. 
// @@ -1060,9 +1050,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, "are not yet taproot channel compatible", chanPoint) } - case chanType.HasAnchors(): - towerClient = p.cfg.AnchorTowerClient - default: + } else { towerClient = p.cfg.TowerClient } diff --git a/rpcserver.go b/rpcserver.go index c612bb4ad..383cab3e2 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -744,11 +744,10 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, r.cfg, s.cc, r.cfg.networkDir, macService, atpl, invoiceRegistry, s.htlcSwitch, r.cfg.ActiveNetParams.Params, s.chanRouter, routerBackend, s.nodeSigner, s.graphDB, s.chanStateDB, - s.sweeper, tower, s.towerClient, s.anchorTowerClient, - r.cfg.net.ResolveTCPAddr, genInvoiceFeatures, - genAmpInvoiceFeatures, s.getNodeAnnouncement, - s.updateAndBrodcastSelfNode, parseAddr, rpcsLog, - s.aliasMgr.GetPeerAlias, + s.sweeper, tower, s.towerClientMgr, r.cfg.net.ResolveTCPAddr, + genInvoiceFeatures, genAmpInvoiceFeatures, + s.getNodeAnnouncement, s.updateAndBrodcastSelfNode, parseAddr, + rpcsLog, s.aliasMgr.GetPeerAlias, ) if err != nil { return err diff --git a/server.go b/server.go index fcc489435..fd53ee5e6 100644 --- a/server.go +++ b/server.go @@ -282,9 +282,7 @@ type server struct { sphinx *hop.OnionProcessor - towerClient wtclient.Client - - anchorTowerClient wtclient.Client + towerClientMgr *wtclient.Manager connMgr *connmgr.ConnManager @@ -1548,40 +1546,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID - s.towerClient, err = wtclient.New(&wtclient.Config{ - FetchClosedChannel: fetchClosedChannel, - BuildBreachRetribution: buildBreachRetribution, - SessionCloseRange: cfg.WtClient.SessionCloseRange, - ChainNotifier: s.cc.ChainNotifier, - SubscribeChannelEvents: func() (subscribe.Subscription, - error) { - - return s.channelNotifier. - SubscribeChannelEvents() - }, - Signer: cc.Wallet.Cfg.Signer, - NewAddress: newSweepPkScriptGen(cc.Wallet), - SecretKeyRing: s.cc.KeyRing, - Dial: cfg.net.Dial, - AuthDial: authDial, - DB: dbs.TowerClientDB, - Policy: policy, - ChainHash: *s.cfg.ActiveNetParams.GenesisHash, - MinBackoff: 10 * time.Second, - MaxBackoff: 5 * time.Minute, - MaxTasksInMemQueue: cfg.WtClient.MaxTasksInMemQueue, - }) - if err != nil { - return nil, err - } - // Copy the policy for legacy channels and set the blob flag // signalling support for anchor channels. 
anchorPolicy := policy - anchorPolicy.TxPolicy.BlobType |= - blob.Type(blob.FlagAnchorChannel) + anchorPolicy.BlobType |= blob.Type(blob.FlagAnchorChannel) - s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ + s.towerClientMgr, err = wtclient.NewManager(&wtclient.Config{ FetchClosedChannel: fetchClosedChannel, BuildBreachRetribution: buildBreachRetribution, SessionCloseRange: cfg.WtClient.SessionCloseRange, @@ -1598,12 +1568,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, Dial: cfg.net.Dial, AuthDial: authDial, DB: dbs.TowerClientDB, - Policy: anchorPolicy, ChainHash: *s.cfg.ActiveNetParams.GenesisHash, MinBackoff: 10 * time.Second, MaxBackoff: 5 * time.Minute, MaxTasksInMemQueue: cfg.WtClient.MaxTasksInMemQueue, - }) + }, policy, anchorPolicy) if err != nil { return nil, err } @@ -1925,19 +1894,12 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.htlcNotifier.Stop) - if s.towerClient != nil { - if err := s.towerClient.Start(); err != nil { + if s.towerClientMgr != nil { + if err := s.towerClientMgr.Start(); err != nil { startErr = err return } - cleanup = cleanup.add(s.towerClient.Stop) - } - if s.anchorTowerClient != nil { - if err := s.anchorTowerClient.Start(); err != nil { - startErr = err - return - } - cleanup = cleanup.add(s.anchorTowerClient.Stop) + cleanup = cleanup.add(s.towerClientMgr.Stop) } if err := s.sweeper.Start(); err != nil { @@ -2310,16 +2272,10 @@ func (s *server) Stop() error { // client which will reliably flush all queued states to the // tower. If this is halted for any reason, the force quit timer // will kick in and abort to allow this method to return. - if s.towerClient != nil { - if err := s.towerClient.Stop(); err != nil { + if s.towerClientMgr != nil { + if err := s.towerClientMgr.Stop(); err != nil { srvrLog.Warnf("Unable to shut down tower "+ - "client: %v", err) - } - } - if s.anchorTowerClient != nil { - if err := s.anchorTowerClient.Stop(); err != nil { - srvrLog.Warnf("Unable to shut down anchor "+ - "tower client: %v", err) + "client manager: %v", err) } } @@ -3807,6 +3763,17 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, } } + // If we directly set the peer.Config TowerClient member to the + // s.towerClientMgr then in the case that the s.towerClientMgr is nil, + // the peer.Config's TowerClient member will not evaluate to nil even + // though the underlying value is nil. To avoid this gotcha which can + // cause a panic, we need to explicitly pass nil to the peer.Config's + // TowerClient if needed. + var towerClient wtclient.ClientManager + if s.towerClientMgr != nil { + towerClient = s.towerClientMgr + } + // Now that we've established a connection, create a peer, and it to the // set of currently active peers. 
Configure the peer with the incoming // and outgoing broadcast deltas to prevent htlcs from being accepted or @@ -3845,8 +3812,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, Invoices: s.invoices, ChannelNotifier: s.channelNotifier, HtlcNotifier: s.htlcNotifier, - TowerClient: s.towerClient, - AnchorTowerClient: s.anchorTowerClient, + TowerClient: towerClient, DisconnectPeer: s.DisconnectPeer, GenNodeAnnouncement: func(...netann.NodeAnnModifier) ( lnwire.NodeAnnouncement, error) { diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 7706dfd27..360d7fa31 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -113,8 +113,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, - towerClient wtclient.Client, - anchorTowerClient wtclient.Client, + towerClientMgr wtclient.ClientManager, tcpResolver lncfg.TCPResolver, genInvoiceFeatures func() *lnwire.FeatureVector, genAmpInvoiceFeatures func() *lnwire.FeatureVector, @@ -287,15 +286,12 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, case *wtclientrpc.Config: subCfgValue := extractReflectValue(subCfg) - if towerClient != nil && anchorTowerClient != nil { + if towerClientMgr != nil { subCfgValue.FieldByName("Active").Set( - reflect.ValueOf(towerClient != nil), + reflect.ValueOf(towerClientMgr != nil), ) - subCfgValue.FieldByName("Client").Set( - reflect.ValueOf(towerClient), - ) - subCfgValue.FieldByName("AnchorClient").Set( - reflect.ValueOf(anchorTowerClient), + subCfgValue.FieldByName("ClientMgr").Set( + reflect.ValueOf(towerClientMgr), ) } subCfgValue.FieldByName("Resolver").Set( diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index f9036e87f..24d974ced 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -11,19 +11,12 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/subscribe" - "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtserver" @@ -56,13 +49,11 @@ const ( // genSessionFilter constructs a filter that can be used to select sessions only // if they match the policy of the client (namely anchor vs legacy). If // activeOnly is set, then only active sessions will be returned. -func (c *TowerClient) genSessionFilter( +func (c *client) genSessionFilter( activeOnly bool) wtdb.ClientSessionFilterFn { return func(session *wtdb.ClientSession) bool { - if c.cfg.Policy.IsAnchorChannel() != - session.Policy.IsAnchorChannel() { - + if c.cfg.Policy.TxPolicy != session.Policy.TxPolicy { return false } @@ -95,157 +86,18 @@ type RegisteredTower struct { ActiveSessionCandidate bool } -// Client is the primary interface used by the daemon to control a client's -// lifecycle and backup revoked states. 
-type Client interface { - // AddTower adds a new watchtower reachable at the given address and - // considers it for new sessions. If the watchtower already exists, then - // any new addresses included will be considered when dialing it for - // session negotiations and backups. - AddTower(*lnwire.NetAddress) error - - // RemoveTower removes a watchtower from being considered for future - // session negotiations and from being used for any subsequent backups - // until it's added again. If an address is provided, then this call - // only serves as a way of removing the address from the watchtower - // instead. - RemoveTower(*btcec.PublicKey, net.Addr) error - - // RegisteredTowers retrieves the list of watchtowers registered with - // the client. - RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower, - error) - - // LookupTower retrieves a registered watchtower through its public key. - LookupTower(*btcec.PublicKey, - ...wtdb.ClientSessionListOption) (*RegisteredTower, error) - - // Stats returns the in-memory statistics of the client since startup. - Stats() ClientStats - - // Policy returns the active client policy configuration. - Policy() wtpolicy.Policy - - // RegisterChannel persistently initializes any channel-dependent - // parameters within the client. This should be called during link - // startup to ensure that the client is able to support the link during - // operation. - RegisterChannel(lnwire.ChannelID) error - - // BackupState initiates a request to back up a particular revoked - // state. If the method returns nil, the backup is guaranteed to be - // successful unless the justice transaction would create dust outputs - // when trying to abide by the negotiated policy. - BackupState(chanID *lnwire.ChannelID, stateNum uint64) error - - // Start initializes the watchtower client, allowing it process requests - // to backup revoked channel states. - Start() error - - // Stop attempts a graceful shutdown of the watchtower client. In doing - // so, it will attempt to flush the pipeline and deliver any queued - // states to the tower before exiting. - Stop() error -} - -// Config provides the TowerClient with access to the resources it requires to -// perform its duty. All nillable fields must be non-nil for the tower to be -// initialized properly. -type Config struct { - // Signer provides access to the wallet so that the client can sign - // justice transactions that spend from a remote party's commitment - // transaction. - Signer input.Signer - - // SubscribeChannelEvents can be used to subscribe to channel event - // notifications. - SubscribeChannelEvents func() (subscribe.Subscription, error) - - // FetchClosedChannel can be used to fetch the info about a closed - // channel. If the channel is not found or not yet closed then - // channeldb.ErrClosedChannelNotFound will be returned. - FetchClosedChannel func(cid lnwire.ChannelID) ( - *channeldb.ChannelCloseSummary, error) - - // ChainNotifier can be used to subscribe to block notifications. - ChainNotifier chainntnfs.ChainNotifier - - // BuildBreachRetribution is a function closure that allows the client - // fetch the breach retribution info for a certain channel at a certain - // revoked commitment height. - BuildBreachRetribution BreachRetributionBuilder - - // NewAddress generates a new on-chain sweep pkscript. - NewAddress func() ([]byte, error) - - // SecretKeyRing is used to derive the session keys used to communicate - // with the tower. 
The client only stores the KeyLocators internally so - // that we never store private keys on disk. - SecretKeyRing ECDHKeyRing - - // Dial connects to an addr using the specified net and returns the - // connection object. - Dial tor.DialFunc - - // AuthDialer establishes a brontide connection over an onion or clear - // network. - AuthDial AuthDialer - - // DB provides access to the client's stable storage medium. - DB DB - - // Policy is the session policy the client will propose when creating - // new sessions with the tower. If the policy differs from any active - // sessions recorded in the database, those sessions will be ignored and - // new sessions will be requested immediately. - Policy wtpolicy.Policy - - // ChainHash identifies the chain that the client is on and for which - // the tower must be watching to monitor for breaches. - ChainHash chainhash.Hash - - // ReadTimeout is the duration we will wait during a read before - // breaking out of a blocking read. If the value is less than or equal - // to zero, the default will be used instead. - ReadTimeout time.Duration - - // WriteTimeout is the duration we will wait during a write before - // breaking out of a blocking write. If the value is less than or equal - // to zero, the default will be used instead. - WriteTimeout time.Duration - - // MinBackoff defines the initial backoff applied to connections with - // watchtowers. Subsequent backoff durations will grow exponentially up - // until MaxBackoff. - MinBackoff time.Duration - - // MaxBackoff defines the maximum backoff applied to connections with - // watchtowers. If the exponential backoff produces a timeout greater - // than this value, the backoff will be clamped to MaxBackoff. - MaxBackoff time.Duration - - // SessionCloseRange is the range over which we will generate a random - // number of blocks to delay closing a session after its last channel - // has been closed. - SessionCloseRange uint32 - - // MaxTasksInMemQueue is the maximum number of backup tasks that should - // be kept in-memory. Any more tasks will overflow to disk. - MaxTasksInMemQueue uint64 -} - // BreachRetributionBuilder is a function that can be used to construct a // BreachRetribution from a channel ID and a commitment height. type BreachRetributionBuilder func(id lnwire.ChannelID, commitHeight uint64) (*lnwallet.BreachRetribution, channeldb.ChannelType, error) -// newTowerMsg is an internal message we'll use within the TowerClient to signal +// newTowerMsg is an internal message we'll use within the client to signal // that a new tower can be considered. type newTowerMsg struct { - // addr is the tower's reachable address that we'll use to establish a - // connection with. - addr *lnwire.NetAddress + // tower holds the info about the new Tower or new tower address + // required to connect to it. + tower *Tower // errChan is the channel through which we'll send a response back to // the caller when handling their request. @@ -254,9 +106,12 @@ type newTowerMsg struct { errChan chan error } -// staleTowerMsg is an internal message we'll use within the TowerClient to +// staleTowerMsg is an internal message we'll use within the client to // signal that a tower should no longer be considered. type staleTowerMsg struct { + // id is the unique database identifier for the tower. + id wtdb.TowerID + // pubKey is the identifying public key of the watchtower. 
pubKey *btcec.PublicKey @@ -273,14 +128,23 @@ type staleTowerMsg struct { errChan chan error } -// TowerClient is a concrete implementation of the Client interface, offering a -// non-blocking, reliable subsystem for backing up revoked states to a specified -// private tower. -type TowerClient struct { - started sync.Once - stopped sync.Once +// clientCfg holds the configuration values required by a client. +type clientCfg struct { + *Config - cfg *Config + // Policy is the session policy the client will propose when creating + // new sessions with the tower. If the policy differs from any active + // sessions recorded in the database, those sessions will be ignored and + // new sessions will be requested immediately. + Policy wtpolicy.Policy + + getSweepScript func(lnwire.ChannelID) ([]byte, bool) +} + +// client manages backing up revoked states for all states that fall under a +// specific policy type. +type client struct { + cfg *clientCfg log btclog.Logger @@ -294,13 +158,8 @@ type TowerClient struct { sessionQueue *sessionQueue prevTask *wtdb.BackupID - closableSessionQueue *sessionCloseMinHeap - - backupMu sync.Mutex - chanInfos wtdb.ChannelInfos - statTicker *time.Ticker - stats *ClientStats + stats *clientStats newTowers chan *newTowerMsg staleTowers chan *staleTowerMsg @@ -309,28 +168,9 @@ type TowerClient struct { quit chan struct{} } -// Compile-time constraint to ensure *TowerClient implements the Client -// interface. -var _ Client = (*TowerClient)(nil) - -// New initializes a new TowerClient from the provide Config. An error is +// newClient initializes a new client from the provided clientCfg. An error is // returned if the client could not be initialized. -func New(config *Config) (*TowerClient, error) { - // Copy the config to prevent side effects from modifying both the - // internal and external version of the Config. - cfg := new(Config) - *cfg = *config - - // Set the read timeout to the default if none was provided. - if cfg.ReadTimeout <= 0 { - cfg.ReadTimeout = DefaultReadTimeout - } - - // Set the write timeout to the default if none was provided. 
- if cfg.WriteTimeout <= 0 { - cfg.WriteTimeout = DefaultWriteTimeout - } - +func newClient(cfg *clientCfg) (*client, error) { identifier, err := cfg.Policy.BlobType.Identifier() if err != nil { return nil, err @@ -339,11 +179,6 @@ func New(config *Config) (*TowerClient, error) { plog := build.NewPrefixLog(prefix, log) - chanInfos, err := cfg.DB.FetchChanInfos() - if err != nil { - return nil, err - } - queueDB := cfg.DB.GetDBQueue([]byte(identifier)) queue, err := NewDiskOverflowQueue[*wtdb.BackupID]( queueDB, cfg.MaxTasksInMemQueue, plog, @@ -352,18 +187,16 @@ func New(config *Config) (*TowerClient, error) { return nil, err } - c := &TowerClient{ - cfg: cfg, - log: plog, - pipeline: queue, - activeSessions: newSessionQueueSet(), - chanInfos: chanInfos, - closableSessionQueue: newSessionCloseMinHeap(), - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(ClientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - quit: make(chan struct{}), + c := &client{ + cfg: cfg, + log: plog, + pipeline: queue, + activeSessions: newSessionQueueSet(), + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(clientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + quit: make(chan struct{}), } candidateTowers := newTowerListIterator() @@ -513,251 +346,114 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, return sessions, nil } -// Start initializes the watchtower client by loading or negotiating an active +// start initializes the watchtower client by loading or negotiating an active // session and then begins processing backup tasks from the request pipeline. -func (c *TowerClient) Start() error { - var returnErr error - c.started.Do(func() { - c.log.Infof("Watchtower client starting") +func (c *client) start() error { + c.log.Infof("Watchtower client starting") - // First, restart a session queue for any sessions that have - // committed but unacked state updates. This ensures that these - // sessions will be able to flush the committed updates after a - // restart. - fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates - for _, session := range c.candidateSessions { - committedUpdates, err := fetchCommittedUpdates( - &session.ID, - ) - if err != nil { - returnErr = err - return - } - - if len(committedUpdates) > 0 { - c.log.Infof("Starting session=%s to process "+ - "%d committed backups", session.ID, - len(committedUpdates)) - - c.initActiveQueue(session, committedUpdates) - } - } - - chanSub, err := c.cfg.SubscribeChannelEvents() - if err != nil { - returnErr = err - return - } - - // Iterate over the list of registered channels and check if - // any of them can be marked as closed. - for id := range c.chanInfos { - isClosed, closedHeight, err := c.isChannelClosed(id) - if err != nil { - returnErr = err - return - } - - if !isClosed { - continue - } - - _, err = c.cfg.DB.MarkChannelClosed(id, closedHeight) - if err != nil { - c.log.Errorf("could not mark channel(%s) as "+ - "closed: %v", id, err) - - continue - } - - // Since the channel has been marked as closed, we can - // also remove it from the channel summaries map. - delete(c.chanInfos, id) - } - - // Load all closable sessions. 
- closableSessions, err := c.cfg.DB.ListClosableSessions() - if err != nil { - returnErr = err - return - } - - err = c.trackClosableSessions(closableSessions) - if err != nil { - returnErr = err - return - } - - c.wg.Add(1) - go c.handleChannelCloses(chanSub) - - // Subscribe to new block events. - blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn( - nil, + // First, restart a session queue for any sessions that have + // committed but unacked state updates. This ensures that these + // sessions will be able to flush the committed updates after a + // restart. + fetchCommittedUpdates := c.cfg.DB.FetchSessionCommittedUpdates + for _, session := range c.candidateSessions { + committedUpdates, err := fetchCommittedUpdates( + &session.ID, ) if err != nil { - returnErr = err - return + return err } - c.wg.Add(1) - go c.handleClosableSessions(blockEvents) + if len(committedUpdates) > 0 { + c.log.Infof("Starting session=%s to process "+ + "%d committed backups", session.ID, + len(committedUpdates)) - // Now start the session negotiator, which will allow us to - // request new session as soon as the backupDispatcher starts - // up. - err = c.negotiator.Start() - if err != nil { - returnErr = err - return + c.initActiveQueue(session, committedUpdates) } - - // Start the task pipeline to which new backup tasks will be - // submitted from active links. - err = c.pipeline.Start() - if err != nil { - returnErr = err - return - } - - c.wg.Add(1) - go c.backupDispatcher() - - c.log.Infof("Watchtower client started successfully") - }) - return returnErr -} - -// Stop idempotently initiates a graceful shutdown of the watchtower client. -func (c *TowerClient) Stop() error { - var returnErr error - c.stopped.Do(func() { - c.log.Debugf("Stopping watchtower client") - - // 1. Stop the session negotiator. - err := c.negotiator.Stop() - if err != nil { - returnErr = err - } - - // 2. Stop the backup dispatcher and any other goroutines. - close(c.quit) - c.wg.Wait() - - // 3. If there was a left over 'prevTask' from the backup - // dispatcher, replay that onto the pipeline. - if c.prevTask != nil { - err = c.pipeline.QueueBackupID(c.prevTask) - if err != nil { - returnErr = err - } - } - - // 4. Shutdown all active session queues in parallel. These will - // exit once all unhandled updates have been replayed to the - // task pipeline. - c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() { - return func() { - err := s.Stop(false) - if err != nil { - c.log.Errorf("could not stop session "+ - "queue: %s: %v", s.ID(), err) - - returnErr = err - } - } - }) - - // 5. Shutdown the backup queue, which will prevent any further - // updates from being accepted. - if err = c.pipeline.Stop(); err != nil { - returnErr = err - } - - c.log.Debugf("Client successfully stopped, stats: %s", c.stats) - }) - - return returnErr -} - -// RegisterChannel persistently initializes any channel-dependent parameters -// within the client. This should be called during link startup to ensure that -// the client is able to support the link during operation. -func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // If a pkscript for this channel already exists, the channel has been - // previously registered. - if _, ok := c.chanInfos[chanID]; ok { - return nil } - // Otherwise, generate a new sweep pkscript used to sweep funds for this - // channel. 
- pkScript, err := c.cfg.NewAddress() + // Now start the session negotiator, which will allow us to request new + // session as soon as the backupDispatcher starts up. + err := c.negotiator.Start() if err != nil { return err } - // Persist the sweep pkscript so that restarts will not introduce - // address inflation when the channel is reregistered after a restart. - err = c.cfg.DB.RegisterChannel(chanID, pkScript) + // Start the task pipeline to which new backup tasks will be + // submitted from active links. + err = c.pipeline.Start() if err != nil { return err } - // Finally, cache the pkscript in our in-memory cache to avoid db - // lookups for the remainder of the daemon's execution. - c.chanInfos[chanID] = &wtdb.ChannelInfo{ - ClientChanSummary: wtdb.ClientChanSummary{ - SweepPkScript: pkScript, - }, - } + c.wg.Add(1) + go c.backupDispatcher() + + c.log.Infof("Watchtower client started successfully") return nil } -// BackupState initiates a request to back up a particular revoked state. If the +// stop idempotently initiates a graceful shutdown of the watchtower client. +func (c *client) stop() error { + var returnErr error + c.log.Debugf("Stopping watchtower client") + + // 1. Stop the session negotiator. + err := c.negotiator.Stop() + if err != nil { + returnErr = err + } + + // 2. Stop the backup dispatcher and any other goroutines. + close(c.quit) + c.wg.Wait() + + // 3. If there was a left over 'prevTask' from the backup + // dispatcher, replay that onto the pipeline. + if c.prevTask != nil { + err = c.pipeline.QueueBackupID(c.prevTask) + if err != nil { + returnErr = err + } + } + + // 4. Shutdown all active session queues in parallel. These will + // exit once all unhandled updates have been replayed to the + // task pipeline. + c.activeSessions.ApplyAndWait(func(s *sessionQueue) func() { + return func() { + err := s.Stop(false) + if err != nil { + c.log.Errorf("could not stop session "+ + "queue: %s: %v", s.ID(), err) + + returnErr = err + } + } + }) + + // 5. Shutdown the backup queue, which will prevent any further + // updates from being accepted. + if err = c.pipeline.Stop(); err != nil { + returnErr = err + } + + c.log.Debugf("Client successfully stopped, stats: %s", c.stats) + + return returnErr +} + +// backupState initiates a request to back up a particular revoked state. If the // method returns nil, the backup is guaranteed to be successful unless the: // - justice transaction would create dust outputs when trying to abide by the // negotiated policy, or // - breached outputs contain too little value to sweep at the target sweep // fee rate. -func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, +func (c *client) backupState(chanID *lnwire.ChannelID, stateNum uint64) error { - // Make sure that this channel is registered with the tower client. - c.backupMu.Lock() - info, ok := c.chanInfos[*chanID] - if !ok { - c.backupMu.Unlock() - - return ErrUnregisteredChannel - } - - // Ignore backups that have already been presented to the client. - var duplicate bool - info.MaxHeight.WhenSome(func(maxHeight uint64) { - if stateNum <= maxHeight { - duplicate = true - } - }) - if duplicate { - c.backupMu.Unlock() - - c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ - "height=%d", chanID, stateNum) - - return nil - } - - // This backup has a higher commit height than any known backup for this - // channel. We'll update our tip so that we won't accept it again if the - // link flaps. 
- c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) - c.backupMu.Unlock() - id := &wtdb.BackupID{ ChanID: *chanID, CommitHeight: stateNum, @@ -771,7 +467,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // active client's advertised policy will be ignored, but may be resumed if the // client is restarted with a matching policy. If no candidates were found, nil // is returned to signal that we need to request a new policy. -func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { +func (c *client) nextSessionQueue() (*sessionQueue, error) { // Select any candidate session at random, and remove it from the set of // candidate sessions. var candidateSession *ClientSession @@ -809,220 +505,15 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { return c.getOrInitActiveQueue(candidateSession, updates), nil } -// handleChannelCloses listens for channel close events and marks channels as -// closed in the DB. -// -// NOTE: This method MUST be run as a goroutine. -func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { - defer c.wg.Done() - - c.log.Debugf("Starting channel close handler") - defer c.log.Debugf("Stopping channel close handler") - - for { - select { - case update, ok := <-chanSub.Updates(): - if !ok { - c.log.Debugf("Channel notifier has exited") - return - } - - // We only care about channel-close events. - event, ok := update.(channelnotifier.ClosedChannelEvent) - if !ok { - continue - } - - chanID := lnwire.NewChanIDFromOutPoint( - &event.CloseSummary.ChanPoint, - ) - - c.log.Debugf("Received ClosedChannelEvent for "+ - "channel: %s", chanID) - - err := c.handleClosedChannel( - chanID, event.CloseSummary.CloseHeight, - ) - if err != nil { - c.log.Errorf("Could not handle channel close "+ - "event for channel(%s): %v", chanID, - err) - } - - case <-c.quit: - return - } - } -} - -// handleClosedChannel handles the closure of a single channel. It will mark the -// channel as closed in the DB, then it will handle all the sessions that are -// now closable due to the channel closure. -func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, - closeHeight uint32) error { - - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // We only care about channels registered with the tower client. - if _, ok := c.chanInfos[chanID]; !ok { - return nil - } - - c.log.Debugf("Marking channel(%s) as closed", chanID) - - sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) - if err != nil { - return fmt.Errorf("could not mark channel(%s) as closed: %w", - chanID, err) - } - - closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) - for _, sess := range sessions { - closableSessions[sess] = closeHeight - } - - c.log.Debugf("Tracking %d new closable sessions as a result of "+ - "closing channel %s", len(closableSessions), chanID) - - err = c.trackClosableSessions(closableSessions) - if err != nil { - return fmt.Errorf("could not track closable sessions: %w", err) - } - - delete(c.chanInfos, chanID) - - return nil -} - -// handleClosableSessions listens for new block notifications. For each block, -// it checks the closableSessionQueue to see if there is a closable session with -// a delete-height smaller than or equal to the new block, if there is then the -// tower is informed that it can delete the session, and then we also delete it -// from our DB. 
-func (c *TowerClient) handleClosableSessions( - blocksChan *chainntnfs.BlockEpochEvent) { - - defer c.wg.Done() - - c.log.Debug("Starting closable sessions handler") - defer c.log.Debug("Stopping closable sessions handler") - - for { - select { - case newBlock := <-blocksChan.Epochs: - if newBlock == nil { - return - } - - height := uint32(newBlock.Height) - for { - select { - case <-c.quit: - return - default: - } - - // If there are no closable sessions that we - // need to handle, then we are done and can - // reevaluate when the next block comes. - item := c.closableSessionQueue.Top() - if item == nil { - break - } - - // If there is closable session but the delete - // height we have set for it is after the - // current block height, then our work is done. - if item.deleteHeight > height { - break - } - - // Otherwise, we pop this item from the heap - // and handle it. - c.closableSessionQueue.Pop() - - // Stop the session and remove it from the - // in-memory set. - err := c.activeSessions.StopAndRemove( - item.sessionID, - ) - if err != nil { - c.log.Errorf("could not remove "+ - "session(%s) from in-memory "+ - "set: %v", item.sessionID, err) - - return - } - - // Fetch the session from the DB so that we can - // extract the Tower info. - sess, err := c.cfg.DB.GetClientSession( - item.sessionID, - ) - if err != nil { - c.log.Errorf("error calling "+ - "GetClientSession for "+ - "session %s: %v", - item.sessionID, err) - - continue - } - - err = c.deleteSessionFromTower(sess) - if err != nil { - c.log.Errorf("error deleting "+ - "session %s from tower: %v", - sess.ID, err) - - continue - } - - err = c.cfg.DB.DeleteSession(item.sessionID) - if err != nil { - c.log.Errorf("could not delete "+ - "session(%s) from DB: %w", - sess.ID, err) - - continue - } - } - - case <-c.quit: - return - } - } -} - -// trackClosableSessions takes in a map of session IDs to the earliest block -// height at which the session should be deleted. For each of the sessions, -// a random delay is added to the block height and the session is added to the -// closableSessionQueue. -func (c *TowerClient) trackClosableSessions( - sessions map[wtdb.SessionID]uint32) error { - - // For each closable session, add a random delay to its close - // height and add it to the closableSessionQueue. - for sID, blockHeight := range sessions { - delay, err := newRandomDelay(c.cfg.SessionCloseRange) - if err != nil { - return err - } - - deleteHeight := blockHeight + delay - - c.closableSessionQueue.Push(&sessionCloseItem{ - sessionID: sID, - deleteHeight: deleteHeight, - }) - } - - return nil +// stopAndRemoveSession stops the session with the given ID and removes it from +// the in-memory active sessions set. +func (c *client) stopAndRemoveSession(id wtdb.SessionID) error { + return c.activeSessions.StopAndRemove(id) } // deleteSessionFromTower dials the tower that we created the session with and // attempts to send the tower the DeleteSession message. -func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error { +func (c *client) deleteSessionFromTower(sess *wtdb.ClientSession) error { // First, we check if we have already loaded this tower in our // candidate towers iterator. 
tower, err := c.candidateTowers.GetTower(sess.TowerID) @@ -1146,10 +637,10 @@ func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error { // backupDispatcher processes events coming from the taskPipeline and is // responsible for detecting when the client needs to renegotiate a session to -// fulfill continuing demand. The event loop exits if the TowerClient is quit. +// fulfill continuing demand. The event loop exits if the client is quit. // // NOTE: This method MUST be run as a goroutine. -func (c *TowerClient) backupDispatcher() { +func (c *client) backupDispatcher() { defer c.wg.Done() c.log.Tracef("Starting backup dispatcher") @@ -1188,7 +679,7 @@ func (c *TowerClient) backupDispatcher() { // its corresponding sessions, if any, as new // candidates. case msg := <-c.newTowers: - msg.errChan <- c.handleNewTower(msg) + msg.errChan <- c.handleNewTower(msg.tower) // A tower has been requested to be removed. We'll // only allow removal of it if the address in question @@ -1272,7 +763,7 @@ func (c *TowerClient) backupDispatcher() { // its corresponding sessions, if any, as new // candidates. case msg := <-c.newTowers: - msg.errChan <- c.handleNewTower(msg) + msg.errChan <- c.handleNewTower(msg.tower) // A tower has been removed, so we'll remove certain // information that's persisted and also in our @@ -1295,20 +786,16 @@ func (c *TowerClient) backupDispatcher() { // sessionQueue hasn't been exhausted before proceeding to the next task. Tasks // that are rejected because the active sessionQueue is full will be cached as // the prevTask, and should be reprocessed after obtaining a new sessionQueue. -func (c *TowerClient) processTask(task *wtdb.BackupID) { - c.backupMu.Lock() - summary, ok := c.chanInfos[task.ChanID] +func (c *client) processTask(task *wtdb.BackupID) { + script, ok := c.cfg.getSweepScript(task.ChanID) if !ok { - c.backupMu.Unlock() - log.Infof("not processing task for unregistered channel: %s", task.ChanID) return } - c.backupMu.Unlock() - backupTask := newBackupTask(*task, summary.SweepPkScript) + backupTask := newBackupTask(*task, script) status, accepted := c.sessionQueue.AcceptTask(backupTask) if accepted { @@ -1323,7 +810,7 @@ func (c *TowerClient) processTask(task *wtdb.BackupID) { // prevTask is always removed as a result of this call. The client's // sessionQueue will be removed if accepting the task left the sessionQueue in // an exhausted state. -func (c *TowerClient) taskAccepted(task *wtdb.BackupID, +func (c *client) taskAccepted(task *wtdb.BackupID, newStatus sessionQueueStatus) { c.log.Infof("Queued %v successfully for session %v", task, @@ -1361,7 +848,7 @@ func (c *TowerClient) taskAccepted(task *wtdb.BackupID, // exhausted and not shutting down, the client marks the task as ineligible, as // this implies we couldn't construct a valid justice transaction given the // session's policy. -func (c *TowerClient) taskRejected(task *wtdb.BackupID, +func (c *client) taskRejected(task *wtdb.BackupID, curStatus sessionQueueStatus) { switch curStatus { @@ -1422,7 +909,7 @@ func (c *TowerClient) taskRejected(task *wtdb.BackupID, // dial connects the peer at addr using privKey as our secret key for the // connection. The connection will use the configured Net's resolver to resolve // the address for either Tor or clear net connections. 
-func (c *TowerClient) dial(localKey keychain.SingleKeyECDH, +func (c *client) dial(localKey keychain.SingleKeyECDH, addr *lnwire.NetAddress) (wtserver.Peer, error) { return c.cfg.AuthDial(localKey, addr, c.cfg.Dial) @@ -1432,7 +919,7 @@ func (c *TowerClient) dial(localKey keychain.SingleKeyECDH, // error is returned if a message is not received before the server's read // timeout, the read off the wire failed, or the message could not be // deserialized. -func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) { +func (c *client) readMessage(peer wtserver.Peer) (wtwire.Message, error) { // Set a read timeout to ensure we drop the connection if nothing is // received in a timely manner. err := peer.SetReadDeadline(time.Now().Add(c.cfg.ReadTimeout)) @@ -1466,7 +953,7 @@ func (c *TowerClient) readMessage(peer wtserver.Peer) (wtwire.Message, error) { } // sendMessage sends a watchtower wire message to the target peer. -func (c *TowerClient) sendMessage(peer wtserver.Peer, +func (c *client) sendMessage(peer wtserver.Peer, msg wtwire.Message) error { // Encode the next wire message into the buffer. @@ -1500,7 +987,7 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, // newSessionQueue creates a sessionQueue from a ClientSession loaded from the // database and supplying it with the resources needed by the client. -func (c *TowerClient) newSessionQueue(s *ClientSession, +func (c *client) newSessionQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { return newSessionQueue(&sessionQueueConfig{ @@ -1522,7 +1009,7 @@ func (c *TowerClient) newSessionQueue(s *ClientSession, // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // passed ClientSession. If it exists, the active sessionQueue is returned. // Otherwise, a new sessionQueue is initialized and added to the set. -func (c *TowerClient) getOrInitActiveQueue(s *ClientSession, +func (c *client) getOrInitActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { if sq, ok := c.activeSessions.Get(s.ID); ok { @@ -1536,7 +1023,7 @@ func (c *TowerClient) getOrInitActiveQueue(s *ClientSession, // adds the sessionQueue to the activeSessions set, and starts the sessionQueue // so that it can deliver any committed updates or begin accepting newly // assigned tasks. -func (c *TowerClient) initActiveQueue(s *ClientSession, +func (c *client) initActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { // Initialize the session queue, providing it with all the resources it @@ -1552,32 +1039,16 @@ func (c *TowerClient) initActiveQueue(s *ClientSession, return sq } -// isChanClosed can be used to check if the channel with the given ID has been -// closed. If it has been, the block height in which its closing transaction was -// mined will also be returned. -func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32, - error) { - - chanSum, err := c.cfg.FetchClosedChannel(id) - if errors.Is(err, channeldb.ErrClosedChannelNotFound) { - return false, 0, nil - } else if err != nil { - return false, 0, err - } - - return true, chanSum.CloseHeight, nil -} - -// AddTower adds a new watchtower reachable at the given address and considers +// addTower adds a new watchtower reachable at the given address and considers // it for new sessions. If the watchtower already exists, then any new addresses // included will be considered when dialing it for session negotiations and // backups. 
-func (c *TowerClient) AddTower(addr *lnwire.NetAddress) error { +func (c *client) addTower(tower *Tower) error { errChan := make(chan error, 1) select { case c.newTowers <- &newTowerMsg{ - addr: addr, + tower: tower, errChan: errChan, }: case <-c.pipeline.quit: @@ -1595,20 +1066,7 @@ func (c *TowerClient) AddTower(addr *lnwire.NetAddress) error { // handleNewTower handles a request for a new tower to be added. If the tower // already exists, then its corresponding sessions, if any, will be set // considered as candidates. -func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { - // We'll start by updating our persisted state, followed by our - // in-memory state, with the new tower. This might not actually be a new - // tower, but it might include a new address at which it can be reached. - dbTower, err := c.cfg.DB.CreateTower(msg.addr) - if err != nil { - return err - } - - tower, err := NewTowerFromDBTower(dbTower) - if err != nil { - return err - } - +func (c *client) handleNewTower(tower *Tower) error { c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. @@ -1628,17 +1086,18 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { return nil } -// RemoveTower removes a watchtower from being considered for future session +// removeTower removes a watchtower from being considered for future session // negotiations and from being used for any subsequent backups until it's added // again. If an address is provided, then this call only serves as a way of // removing the address from the watchtower instead. -func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, +func (c *client) removeTower(id wtdb.TowerID, pubKey *btcec.PublicKey, addr net.Addr) error { errChan := make(chan error, 1) select { case c.staleTowers <- &staleTowerMsg{ + id: id, pubKey: pubKey, addr: addr, errChan: errChan, @@ -1659,10 +1118,9 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, // none of the tower's sessions have pending updates, then they will become // inactive and removed as candidates. If the active session queue corresponds // to any of these sessions, a new one will be negotiated. -func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { - // We'll load the tower before potentially removing it in order to - // retrieve its ID within the database. - dbTower, err := c.cfg.DB.LoadTower(msg.pubKey) +func (c *client) handleStaleTower(msg *staleTowerMsg) error { + // We'll first update our in-memory state. + err := c.candidateTowers.RemoveCandidate(msg.id, msg.addr) if err != nil { return err } @@ -1670,19 +1128,14 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // If an address was provided, then we're only meant to remove the // address from the tower. if msg.addr != nil { - return c.removeTowerAddr(dbTower, msg.addr) + return nil } // Otherwise, the tower should no longer be used for future session - // negotiations and backups. First, we'll update our in-memory state - // with the stale tower. - err = c.candidateTowers.RemoveCandidate(dbTower.ID, nil) - if err != nil { - return err - } + // negotiations and backups. 
pubKey := msg.pubKey.SerializeCompressed() - sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID) + sessions, err := c.cfg.DB.ListClientSessions(&msg.id) if err != nil { return fmt.Errorf("unable to retrieve sessions for tower %x: "+ "%v", pubKey, err) @@ -1709,54 +1162,16 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { } } - // Finally, we will update our persisted state with the stale tower. - return c.cfg.DB.RemoveTower(msg.pubKey, nil) -} - -// removeTowerAddr removes the given address from the tower. -func (c *TowerClient) removeTowerAddr(tower *wtdb.Tower, addr net.Addr) error { - if addr == nil { - return fmt.Errorf("an address must be provided") - } - - // We'll first update our in-memory state followed by our persisted - // state with the stale tower. The removal of the tower address from - // the in-memory state will fail if the address is currently being used - // for a session negotiation. - err := c.candidateTowers.RemoveCandidate(tower.ID, addr) - if err != nil { - return err - } - - err = c.cfg.DB.RemoveTower(tower.IdentityKey, addr) - if err != nil { - // If the persisted state update fails, re-add the address to - // our in-memory state. - tower, newTowerErr := NewTowerFromDBTower(tower) - if newTowerErr != nil { - log.Errorf("could not create new in-memory tower: %v", - newTowerErr) - } else { - c.candidateTowers.AddCandidate(tower) - } - - return err - } - return nil } -// RegisteredTowers retrieves the list of watchtowers registered with the +// registeredTowers retrieves the list of watchtowers registered with the // client. -func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( - []*RegisteredTower, error) { - - // Retrieve all of our towers along with all of our sessions. - towers, err := c.cfg.DB.ListTowers() - if err != nil { - return nil, err - } +func (c *client) registeredTowers(towers []*wtdb.Tower, + opts ...wtdb.ClientSessionListOption) ([]*RegisteredTower, error) { + // Generate a filter that will fetch all the client's sessions + // regardless of if they are active or not. opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false))) clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...) @@ -1764,8 +1179,8 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( return nil, err } - // Construct a lookup map that coalesces all of the sessions for a - // specific watchtower. + // Construct a lookup map that coalesces all the sessions for a specific + // watchtower. towerSessions := make( map[wtdb.TowerID]map[wtdb.SessionID]*wtdb.ClientSession, ) @@ -1791,15 +1206,11 @@ func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( return registeredTowers, nil } -// LookupTower retrieves a registered watchtower through its public key. -func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, +// lookupTower retrieves the info of sessions held with the given tower handled +// by this client. +func (c *client) lookupTower(tower *wtdb.Tower, opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) { - tower, err := c.cfg.DB.LoadTower(pubKey) - if err != nil { - return nil, err - } - opts = append(opts, wtdb.WithPreEvalFilterFn(c.genSessionFilter(false))) towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...) @@ -1814,20 +1225,20 @@ func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey, }, nil } -// Stats returns the in-memory statistics of the client since startup. 
-func (c *TowerClient) Stats() ClientStats { - return c.stats.Copy() +// getStats returns the in-memory statistics of the client since startup. +func (c *client) getStats() ClientStats { + return c.stats.getStatsCopy() } -// Policy returns the active client policy configuration. -func (c *TowerClient) Policy() wtpolicy.Policy { +// policy returns the active client policy configuration. +func (c *client) policy() wtpolicy.Policy { return c.cfg.Policy } // logMessage writes information about a message received from a remote peer, // using directional prepositions to signal whether the message was sent or // received. -func (c *TowerClient) logMessage( +func (c *client) logMessage( peer wtserver.Peer, msg wtwire.Message, read bool) { var action = "Received" diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 51dd16e09..95e2664eb 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -395,15 +395,16 @@ func (c *mockChannel) getState( } type testHarness struct { - t *testing.T - cfg harnessCfg - signer *wtmock.MockSigner - capacity lnwire.MilliSatoshi - clientDB *wtdb.ClientDB - clientCfg *wtclient.Config - client wtclient.Client - server *serverHarness - net *mockNet + t *testing.T + cfg harnessCfg + signer *wtmock.MockSigner + capacity lnwire.MilliSatoshi + clientMgr *wtclient.Manager + clientDB *wtdb.ClientDB + clientCfg *wtclient.Config + clientPolicy wtpolicy.Policy + server *serverHarness + net *mockNet blockEvents *mockBlockSub height int32 @@ -486,6 +487,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil } + h.clientPolicy = cfg.policy h.clientCfg = &wtclient.Config{ Signer: signer, SubscribeChannelEvents: func() (subscribe.Subscription, error) { @@ -497,7 +499,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { DB: clientDB, AuthDial: mockNet.AuthDial, SecretKeyRing: wtmock.NewSecretKeyRing(), - Policy: cfg.policy, NewAddress: func() ([]byte, error) { return addrScript, nil }, @@ -525,7 +526,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { h.startClient() t.Cleanup(func() { - require.NoError(t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) require.NoError(t, h.clientDB.Close()) }) @@ -559,10 +560,10 @@ func (h *testHarness) startClient() { Address: towerTCPAddr, } - h.client, err = wtclient.New(h.clientCfg) + h.clientMgr, err = wtclient.NewManager(h.clientCfg, h.clientPolicy) require.NoError(h.t, err) - require.NoError(h.t, h.client.Start()) - require.NoError(h.t, h.client.AddTower(towerAddr)) + require.NoError(h.t, h.clientMgr.Start()) + require.NoError(h.t, h.clientMgr.AddTower(towerAddr)) } // chanIDFromInt creates a unique channel id given a unique integral id. 
@@ -663,7 +664,7 @@ func (h *testHarness) registerChannel(id uint64) { h.t.Helper() chanID := chanIDFromInt(id) - err := h.client.RegisterChannel(chanID) + err := h.clientMgr.RegisterChannel(chanID, channeldb.SingleFunderBit) require.NoError(h.t, err) } @@ -704,8 +705,8 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { chanID := chanIDFromInt(id) - err := h.client.BackupState(&chanID, retribution.RevokedStateNum) - require.ErrorIs(h.t, expErr, err) + err := h.clientMgr.BackupState(&chanID, retribution.RevokedStateNum) + require.ErrorIs(h.t, err, expErr) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -752,7 +753,7 @@ func (h *testHarness) recvPayments(id, from, to uint64, func (h *testHarness) addTower(addr *lnwire.NetAddress) { h.t.Helper() - err := h.client.AddTower(addr) + err := h.clientMgr.AddTower(addr) require.NoError(h.t, err) } @@ -761,7 +762,7 @@ func (h *testHarness) addTower(addr *lnwire.NetAddress) { func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { h.t.Helper() - err := h.client.RemoveTower(pubKey, addr) + err := h.clientMgr.RemoveTower(pubKey, addr) require.NoError(h.t, err) } @@ -1123,7 +1124,7 @@ var clientTests = []clientTest{ ) // Stop the client, subsequent backups should fail. - h.client.Stop() + require.NoError(h.t, h.clientMgr.Stop()) // Advance the channel and try to back up the states. We // expect ErrClientExiting to be returned from @@ -1238,7 +1239,7 @@ var clientTests = []clientTest{ // Stop the client to abort the state updates it has // queued. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) // Restart the server and allow it to ack the updates // after the client retransmits the unacked update. @@ -1249,6 +1250,7 @@ var clientTests = []clientTest{ // Restart the client and allow it to process the // committed update. h.startClient() + h.registerChannel(chanID) // Wait for the committed update to be accepted by the // tower. @@ -1433,7 +1435,7 @@ var clientTests = []clientTest{ h.server.waitForUpdates(nil, waitTime) // Stop the client since it has queued backups. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) // Restart the server and allow it to ack session // creation. @@ -1452,9 +1454,7 @@ var clientTests = []clientTest{ // Assert that the server has updates for the clients // most recent policy. - h.server.assertUpdatesForPolicy( - hints, h.clientCfg.Policy, - ) + h.server.assertUpdatesForPolicy(hints, h.clientPolicy) }, }, { @@ -1485,7 +1485,7 @@ var clientTests = []clientTest{ h.server.waitForUpdates(nil, waitTime) // Stop the client since it has queued backups. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) // Restart the server and allow it to ack session // creation. @@ -1496,7 +1496,7 @@ var clientTests = []clientTest{ // Restart the client with a new policy, which will // immediately try to overwrite the prior session with // the old policy. - h.clientCfg.Policy.SweepFeeRate *= 2 + h.clientPolicy.SweepFeeRate *= 2 h.startClient() // Wait for all the updates to be populated in the @@ -1505,9 +1505,7 @@ var clientTests = []clientTest{ // Assert that the server has updates for the clients // most recent policy. 
- h.server.assertUpdatesForPolicy( - hints, h.clientCfg.Policy, - ) + h.server.assertUpdatesForPolicy(hints, h.clientPolicy) }, }, { @@ -1541,7 +1539,7 @@ var clientTests = []clientTest{ h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Stop the client, which should have no more backups. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) // Record the policy that the first half was stored // under. We'll expect the second half to also be @@ -1549,11 +1547,12 @@ var clientTests = []clientTest{ // adjusting the MaxUpdates. The client should detect // that the two policies have equivalent TxPolicies and // continue using the first. - expPolicy := h.clientCfg.Policy + expPolicy := h.clientPolicy // Restart the client with a new policy. - h.clientCfg.Policy.MaxUpdates = 20 + h.clientPolicy.MaxUpdates = 20 h.startClient() + h.registerChannel(chanID) // Now, queue the second half of the retributions. h.backupStates(chanID, numUpdates/2, numUpdates, nil) @@ -1602,8 +1601,9 @@ var clientTests = []clientTest{ // Restart the client, so we can ensure the deduping is // maintained across restarts. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) h.startClient() + h.registerChannel(chanID) // Try to back up the full range of retributions. Only // the second half should actually be sent. @@ -1713,11 +1713,11 @@ var clientTests = []clientTest{ h.server.addr = towerAddr // Add the new tower address to the client. - err = h.client.AddTower(towerAddr) + err = h.clientMgr.AddTower(towerAddr) require.NoError(h.t, err) // Remove the old tower address from the client. - err = h.client.RemoveTower( + err = h.clientMgr.RemoveTower( towerAddr.IdentityKey, oldAddr, ) require.NoError(h.t, err) @@ -1751,7 +1751,7 @@ var clientTests = []clientTest{ // negotiation with the server will be in progress, so // the client should be able to remove the server. err := wait.NoError(func() error { - return h.client.RemoveTower( + return h.clientMgr.RemoveTower( h.server.addr.IdentityKey, nil, ) }, waitTime) @@ -1794,11 +1794,11 @@ var clientTests = []clientTest{ require.NoError(h.t, h.server.server.Start()) // Re-add the server to the client - err = h.client.AddTower(h.server.addr) + err = h.clientMgr.AddTower(h.server.addr) require.NoError(h.t, err) // Also add the new tower address. - err = h.client.AddTower(towerAddr) + err = h.clientMgr.AddTower(towerAddr) require.NoError(h.t, err) // Assert that if the client attempts to remove the @@ -1806,7 +1806,7 @@ var clientTests = []clientTest{ // address currently being locked for session // negotiation. err = wait.Predicate(func() bool { - err = h.client.RemoveTower( + err = h.clientMgr.RemoveTower( h.server.addr.IdentityKey, h.server.addr.Address, ) @@ -1817,7 +1817,7 @@ var clientTests = []clientTest{ // Assert that the second address can be removed since // it is not being used for session negotiation. err = wait.NoError(func() error { - return h.client.RemoveTower( + return h.clientMgr.RemoveTower( h.server.addr.IdentityKey, towerTCPAddr, ) }, waitTime) @@ -1829,7 +1829,7 @@ var clientTests = []clientTest{ // Assert that the client can now remove the first // address. err = wait.NoError(func() error { - return h.client.RemoveTower( + return h.clientMgr.RemoveTower( h.server.addr.IdentityKey, nil, ) }, waitTime) @@ -1882,7 +1882,7 @@ var clientTests = []clientTest{ require.False(h.t, h.isSessionClosable(sessionIDs[0])) // Restart the client. 
- require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) h.startClient() // The session should now have been marked as closable. @@ -2069,7 +2069,7 @@ var clientTests = []clientTest{ h.server.waitForUpdates(hints[:numUpdates/2], waitTime) // Now stop the client and reset its database. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) db := newClientDB(h.t) h.clientDB = db @@ -2122,9 +2122,10 @@ var clientTests = []clientTest{ h.backupStates(chanID, 0, numUpdates/2, nil) // Restart the Client. And also now start the server. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) h.server.start() h.startClient() + h.registerChannel(chanID) // Back up a few more states. h.backupStates(chanID, numUpdates/2, numUpdates, nil) @@ -2222,7 +2223,7 @@ var clientTests = []clientTest{ // Now we can remove the old one. err := wait.Predicate(func() bool { - err := h.client.RemoveTower( + err := h.clientMgr.RemoveTower( h.server.addr.IdentityKey, nil, ) @@ -2308,7 +2309,7 @@ var clientTests = []clientTest{ require.NoError(h.t, err) // Now remove the tower. - err = h.client.RemoveTower( + err = h.clientMgr.RemoveTower( h.server.addr.IdentityKey, nil, ) require.NoError(h.t, err) @@ -2395,11 +2396,11 @@ var clientTests = []clientTest{ // Now restart the client. This ensures that the // updates are no longer in the pending queue. - require.NoError(h.t, h.client.Stop()) + require.NoError(h.t, h.clientMgr.Stop()) h.startClient() // Now remove the tower. - err = h.client.RemoveTower( + err = h.clientMgr.RemoveTower( h.server.addr.IdentityKey, nil, ) require.NoError(h.t, err) @@ -2519,7 +2520,7 @@ var clientTests = []clientTest{ // Wait for channel to be "unregistered". chanID := chanIDFromInt(chanIDInt) err = wait.Predicate(func() bool { - err := h.client.BackupState(&chanID, 0) + err := h.clientMgr.BackupState(&chanID, 0) return errors.Is( err, wtclient.ErrUnregisteredChannel, diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go new file mode 100644 index 000000000..73f259085 --- /dev/null +++ b/watchtower/wtclient/manager.go @@ -0,0 +1,902 @@ +package wtclient + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" + "github.com/lightningnetwork/lnd/tor" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// ClientManager is the primary interface used by the daemon to control a +// client's lifecycle and backup revoked states. +type ClientManager interface { + // AddTower adds a new watchtower reachable at the given address and + // considers it for new sessions. If the watchtower already exists, then + // any new addresses included will be considered when dialing it for + // session negotiations and backups. + AddTower(*lnwire.NetAddress) error + + // RemoveTower removes a watchtower from being considered for future + // session negotiations and from being used for any subsequent backups + // until it's added again. 
If an address is provided, then this call + // only serves as a way of removing the address from the watchtower + // instead. + RemoveTower(*btcec.PublicKey, net.Addr) error + + // Stats returns the in-memory statistics of the client since startup. + Stats() ClientStats + + // Policy returns the active client policy configuration. + Policy(blob.Type) (wtpolicy.Policy, error) + + // RegisteredTowers retrieves the list of watchtowers registered with + // the client. It returns a set of registered towers per client policy + // type. + RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( + map[blob.Type][]*RegisteredTower, error) + + // LookupTower retrieves a registered watchtower through its public key. + LookupTower(*btcec.PublicKey, ...wtdb.ClientSessionListOption) ( + map[blob.Type]*RegisteredTower, error) + + // RegisterChannel persistently initializes any channel-dependent + // parameters within the client. This should be called during link + // startup to ensure that the client is able to support the link during + // operation. + RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error + + // BackupState initiates a request to back up a particular revoked + // state. If the method returns nil, the backup is guaranteed to be + // successful unless the justice transaction would create dust outputs + // when trying to abide by the negotiated policy. + BackupState(chanID *lnwire.ChannelID, stateNum uint64) error +} + +// Config provides the client with access to the resources it requires to +// perform its duty. All nillable fields must be non-nil for the tower to be +// initialized properly. +type Config struct { + // Signer provides access to the wallet so that the client can sign + // justice transactions that spend from a remote party's commitment + // transaction. + Signer input.Signer + + // SubscribeChannelEvents can be used to subscribe to channel event + // notifications. + SubscribeChannelEvents func() (subscribe.Subscription, error) + + // FetchClosedChannel can be used to fetch the info about a closed + // channel. If the channel is not found or not yet closed then + // channeldb.ErrClosedChannelNotFound will be returned. + FetchClosedChannel func(cid lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) + + // ChainNotifier can be used to subscribe to block notifications. + ChainNotifier chainntnfs.ChainNotifier + + // BuildBreachRetribution is a function closure that allows the client + // fetch the breach retribution info for a certain channel at a certain + // revoked commitment height. + BuildBreachRetribution BreachRetributionBuilder + + // NewAddress generates a new on-chain sweep pkscript. + NewAddress func() ([]byte, error) + + // SecretKeyRing is used to derive the session keys used to communicate + // with the tower. The client only stores the KeyLocators internally so + // that we never store private keys on disk. + SecretKeyRing ECDHKeyRing + + // Dial connects to an addr using the specified net and returns the + // connection object. + Dial tor.DialFunc + + // AuthDialer establishes a brontide connection over an onion or clear + // network. + AuthDial AuthDialer + + // DB provides access to the client's stable storage medium. + DB DB + + // ChainHash identifies the chain that the client is on and for which + // the tower must be watching to monitor for breaches. + ChainHash chainhash.Hash + + // ReadTimeout is the duration we will wait during a read before + // breaking out of a blocking read. 
If the value is less than or equal + // to zero, the default will be used instead. + ReadTimeout time.Duration + + // WriteTimeout is the duration we will wait during a write before + // breaking out of a blocking write. If the value is less than or equal + // to zero, the default will be used instead. + WriteTimeout time.Duration + + // MinBackoff defines the initial backoff applied to connections with + // watchtowers. Subsequent backoff durations will grow exponentially up + // until MaxBackoff. + MinBackoff time.Duration + + // MaxBackoff defines the maximum backoff applied to connections with + // watchtowers. If the exponential backoff produces a timeout greater + // than this value, the backoff will be clamped to MaxBackoff. + MaxBackoff time.Duration + + // SessionCloseRange is the range over which we will generate a random + // number of blocks to delay closing a session after its last channel + // has been closed. + SessionCloseRange uint32 + + // MaxTasksInMemQueue is the maximum number of backup tasks that should + // be kept in-memory. Any more tasks will overflow to disk. + MaxTasksInMemQueue uint64 +} + +// Manager manages the various tower clients that are active. A client is +// required for each different commitment transaction type. The Manager acts as +// a tower client multiplexer. +type Manager struct { + started sync.Once + stopped sync.Once + + cfg *Config + + clients map[blob.Type]*client + clientsMu sync.Mutex + + backupMu sync.Mutex + chanInfos wtdb.ChannelInfos + chanBlobType map[lnwire.ChannelID]blob.Type + + closableSessionQueue *sessionCloseMinHeap + + wg sync.WaitGroup + quit chan struct{} +} + +var _ ClientManager = (*Manager)(nil) + +// NewManager constructs a new Manager. +func NewManager(config *Config, policies ...wtpolicy.Policy) (*Manager, error) { + // Copy the config to prevent side effects from modifying both the + // internal and external version of the Config. + cfg := *config + + // Set the read timeout to the default if none was provided. + if cfg.ReadTimeout <= 0 { + cfg.ReadTimeout = DefaultReadTimeout + } + + // Set the write timeout to the default if none was provided. + if cfg.WriteTimeout <= 0 { + cfg.WriteTimeout = DefaultWriteTimeout + } + + chanInfos, err := cfg.DB.FetchChanInfos() + if err != nil { + return nil, err + } + + m := &Manager{ + cfg: &cfg, + clients: make(map[blob.Type]*client), + chanBlobType: make(map[lnwire.ChannelID]blob.Type), + chanInfos: chanInfos, + closableSessionQueue: newSessionCloseMinHeap(), + quit: make(chan struct{}), + } + + for _, policy := range policies { + if err = policy.Validate(); err != nil { + return nil, err + } + + if err = m.newClient(policy); err != nil { + return nil, err + } + } + + return m, nil +} + +// newClient constructs a new client and adds it to the set of clients that +// the Manager is keeping track of. +func (m *Manager) newClient(policy wtpolicy.Policy) error { + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + _, ok := m.clients[policy.BlobType] + if ok { + return fmt.Errorf("a client with blob type %s has "+ + "already been registered", policy.BlobType) + } + + cfg := &clientCfg{ + Config: m.cfg, + Policy: policy, + getSweepScript: m.getSweepScript, + } + + client, err := newClient(cfg) + if err != nil { + return err + } + + m.clients[policy.BlobType] = client + + return nil +} + +// Start starts all the clients that have been registered with the Manager. 
+func (m *Manager) Start() error {
+	var returnErr error
+	m.started.Do(func() {
+		chanSub, err := m.cfg.SubscribeChannelEvents()
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		// Iterate over the list of registered channels and check if any
+		// of them can be marked as closed.
+		for id := range m.chanInfos {
+			isClosed, closedHeight, err := m.isChannelClosed(id)
+			if err != nil {
+				returnErr = err
+
+				return
+			}
+
+			if !isClosed {
+				continue
+			}
+
+			_, err = m.cfg.DB.MarkChannelClosed(id, closedHeight)
+			if err != nil {
+				log.Errorf("could not mark channel(%s) as "+
+					"closed: %v", id, err)
+
+				continue
+			}
+
+			// Since the channel has been marked as closed, we can
+			// also remove it from the channel summaries map.
+			delete(m.chanInfos, id)
+		}
+
+		// Load all closable sessions.
+		closableSessions, err := m.cfg.DB.ListClosableSessions()
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		err = m.trackClosableSessions(closableSessions)
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		m.wg.Add(1)
+		go m.handleChannelCloses(chanSub)
+
+		// Subscribe to new block events.
+		blockEvents, err := m.cfg.ChainNotifier.RegisterBlockEpochNtfn(
+			nil,
+		)
+		if err != nil {
+			returnErr = err
+
+			return
+		}
+
+		m.wg.Add(1)
+		go m.handleClosableSessions(blockEvents)
+
+		m.clientsMu.Lock()
+		defer m.clientsMu.Unlock()
+
+		for _, client := range m.clients {
+			if err := client.start(); err != nil {
+				returnErr = err
+				return
+			}
+		}
+	})
+
+	return returnErr
+}
+
+// Stop stops all the clients that the Manager is managing.
+func (m *Manager) Stop() error {
+	var returnErr error
+	m.stopped.Do(func() {
+		m.clientsMu.Lock()
+		defer m.clientsMu.Unlock()
+
+		close(m.quit)
+		m.wg.Wait()
+
+		for _, client := range m.clients {
+			if err := client.stop(); err != nil {
+				returnErr = err
+			}
+		}
+	})
+
+	return returnErr
+}
+
+// AddTower adds a new watchtower reachable at the given address and considers
+// it for new sessions. If the watchtower already exists, then any new addresses
+// included will be considered when dialing it for session negotiations and
+// backups.
+func (m *Manager) AddTower(address *lnwire.NetAddress) error {
+	// We'll start by updating our persisted state, followed by the
+	// in-memory state of each client, with the new tower. This might not
+	// actually be a new tower, but it might include a new address at which
+	// it can be reached.
+	dbTower, err := m.cfg.DB.CreateTower(address)
+	if err != nil {
+		return err
+	}
+
+	tower, err := NewTowerFromDBTower(dbTower)
+	if err != nil {
+		return err
+	}
+
+	m.clientsMu.Lock()
+	defer m.clientsMu.Unlock()
+
+	for blobType, client := range m.clients {
+		clientType, err := blobType.Identifier()
+		if err != nil {
+			return err
+		}
+
+		if err := client.addTower(tower); err != nil {
+			return fmt.Errorf("could not add tower(%x) to the %s "+
+				"tower client: %w",
+				tower.IdentityKey.SerializeCompressed(),
+				clientType, err)
+		}
+	}
+
+	return nil
+}
+
+// RemoveTower removes a watchtower from being considered for future session
+// negotiations and from being used for any subsequent backups until it's added
+// again. If an address is provided, then this call only serves as a way of
+// removing the address from the watchtower instead.
+func (m *Manager) RemoveTower(key *btcec.PublicKey, addr net.Addr) error {
+	// We'll load the tower before potentially removing it in order to
+	// retrieve its ID within the database.
+ dbTower, err := m.cfg.DB.LoadTower(key) + if err != nil { + return err + } + + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + for _, client := range m.clients { + err := client.removeTower(dbTower.ID, key, addr) + if err != nil { + return err + } + } + + if err := m.cfg.DB.RemoveTower(key, addr); err != nil { + // If the persisted state update fails, re-add the address to + // our client's in-memory state. + tower, newTowerErr := NewTowerFromDBTower(dbTower) + if newTowerErr != nil { + log.Errorf("could not create new in-memory tower: %v", + newTowerErr) + + return err + } + + for _, client := range m.clients { + addTowerErr := client.addTower(tower) + if addTowerErr != nil { + log.Errorf("could not re-add tower: %v", + addTowerErr) + } + } + + return err + } + + return nil +} + +// Stats returns the in-memory statistics of the clients managed by the Manager +// since startup. +func (m *Manager) Stats() ClientStats { + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + var resp ClientStats + for _, client := range m.clients { + stats := client.getStats() + resp.NumTasksAccepted += stats.NumTasksAccepted + resp.NumTasksIneligible += stats.NumTasksIneligible + resp.NumTasksPending += stats.NumTasksPending + resp.NumSessionsAcquired += stats.NumSessionsAcquired + resp.NumSessionsExhausted += stats.NumSessionsExhausted + } + + return resp +} + +// RegisteredTowers retrieves the list of watchtowers being used by the various +// clients. +func (m *Manager) RegisteredTowers(opts ...wtdb.ClientSessionListOption) ( + map[blob.Type][]*RegisteredTower, error) { + + towers, err := m.cfg.DB.ListTowers() + if err != nil { + return nil, err + } + + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + resp := make(map[blob.Type][]*RegisteredTower) + for _, client := range m.clients { + towers, err := client.registeredTowers(towers, opts...) + if err != nil { + return nil, err + } + + resp[client.policy().BlobType] = towers + } + + return resp, nil +} + +// LookupTower retrieves a registered watchtower through its public key. +func (m *Manager) LookupTower(key *btcec.PublicKey, + opts ...wtdb.ClientSessionListOption) (map[blob.Type]*RegisteredTower, + error) { + + tower, err := m.cfg.DB.LoadTower(key) + if err != nil { + return nil, err + } + + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + resp := make(map[blob.Type]*RegisteredTower) + for _, client := range m.clients { + tower, err := client.lookupTower(tower, opts...) + if err != nil { + return nil, err + } + + resp[client.policy().BlobType] = tower + } + + return resp, nil +} + +// Policy returns the active client policy configuration for the client using +// the given blob type. +func (m *Manager) Policy(blobType blob.Type) (wtpolicy.Policy, error) { + m.clientsMu.Lock() + defer m.clientsMu.Unlock() + + var policy wtpolicy.Policy + client, ok := m.clients[blobType] + if !ok { + return policy, fmt.Errorf("no client for the given blob type") + } + + return client.policy(), nil +} + +// RegisterChannel persistently initializes any channel-dependent parameters +// within the client. This should be called during link startup to ensure that +// the client is able to support the link during operation. 
+func (m *Manager) RegisterChannel(id lnwire.ChannelID, + chanType channeldb.ChannelType) error { + + blobType := blob.TypeAltruistCommit + if chanType.HasAnchors() { + blobType = blob.TypeAltruistAnchorCommit + } + + m.clientsMu.Lock() + if _, ok := m.clients[blobType]; !ok { + m.clientsMu.Unlock() + + return fmt.Errorf("no client registered for blob type %s", + blobType) + } + m.clientsMu.Unlock() + + m.backupMu.Lock() + defer m.backupMu.Unlock() + + // If a pkscript for this channel already exists, the channel has been + // previously registered. + if _, ok := m.chanInfos[id]; ok { + // Keep track of which blob type this channel will use for + // updates. + m.chanBlobType[id] = blobType + + return nil + } + + // Otherwise, generate a new sweep pkscript used to sweep funds for this + // channel. + pkScript, err := m.cfg.NewAddress() + if err != nil { + return err + } + + // Persist the sweep pkscript so that restarts will not introduce + // address inflation when the channel is reregistered after a restart. + err = m.cfg.DB.RegisterChannel(id, pkScript) + if err != nil { + return err + } + + // Finally, cache the pkscript in our in-memory cache to avoid db + // lookups for the remainder of the daemon's execution. + m.chanInfos[id] = &wtdb.ChannelInfo{ + ClientChanSummary: wtdb.ClientChanSummary{ + SweepPkScript: pkScript, + }, + } + + // Keep track of which blob type this channel will use for updates. + m.chanBlobType[id] = blobType + + return nil +} + +// BackupState initiates a request to back up a particular revoked state. If the +// method returns nil, the backup is guaranteed to be successful unless the +// justice transaction would create dust outputs when trying to abide by the +// negotiated policy. +func (m *Manager) BackupState(chanID *lnwire.ChannelID, stateNum uint64) error { + select { + case <-m.quit: + return ErrClientExiting + default: + } + + // Make sure that this channel is registered with the tower client. + m.backupMu.Lock() + info, ok := m.chanInfos[*chanID] + if !ok { + m.backupMu.Unlock() + + return ErrUnregisteredChannel + } + + // Ignore backups that have already been presented to the client. + var duplicate bool + info.MaxHeight.WhenSome(func(maxHeight uint64) { + if stateNum <= maxHeight { + duplicate = true + } + }) + if duplicate { + m.backupMu.Unlock() + + log.Debugf("Ignoring duplicate backup for chanid=%v at "+ + "height=%d", chanID, stateNum) + + return nil + } + + // This backup has a higher commit height than any known backup for this + // channel. We'll update our tip so that we won't accept it again if the + // link flaps. + m.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) + + blobType, ok := m.chanBlobType[*chanID] + if !ok { + m.backupMu.Unlock() + + return ErrUnregisteredChannel + } + m.backupMu.Unlock() + + m.clientsMu.Lock() + client, ok := m.clients[blobType] + if !ok { + m.clientsMu.Unlock() + + return fmt.Errorf("no client registered for blob type %s", + blobType) + } + m.clientsMu.Unlock() + + return client.backupState(chanID, stateNum) +} + +// isChanClosed can be used to check if the channel with the given ID has been +// closed. If it has been, the block height in which its closing transaction was +// mined will also be returned. 
+func (m *Manager) isChannelClosed(id lnwire.ChannelID) (bool, uint32, + error) { + + chanSum, err := m.cfg.FetchClosedChannel(id) + if errors.Is(err, channeldb.ErrClosedChannelNotFound) { + return false, 0, nil + } else if err != nil { + return false, 0, err + } + + return true, chanSum.CloseHeight, nil +} + +// trackClosableSessions takes in a map of session IDs to the earliest block +// height at which the session should be deleted. For each of the sessions, +// a random delay is added to the block height and the session is added to the +// closableSessionQueue. +func (m *Manager) trackClosableSessions( + sessions map[wtdb.SessionID]uint32) error { + + // For each closable session, add a random delay to its close + // height and add it to the closableSessionQueue. + for sID, blockHeight := range sessions { + delay, err := newRandomDelay(m.cfg.SessionCloseRange) + if err != nil { + return err + } + + deleteHeight := blockHeight + delay + + m.closableSessionQueue.Push(&sessionCloseItem{ + sessionID: sID, + deleteHeight: deleteHeight, + }) + } + + return nil +} + +// handleChannelCloses listens for channel close events and marks channels as +// closed in the DB. +// +// NOTE: This method MUST be run as a goroutine. +func (m *Manager) handleChannelCloses(chanSub subscribe.Subscription) { + defer m.wg.Done() + + log.Debugf("Starting channel close handler") + defer log.Debugf("Stopping channel close handler") + + for { + select { + case update, ok := <-chanSub.Updates(): + if !ok { + log.Debugf("Channel notifier has exited") + return + } + + // We only care about channel-close events. + event, ok := update.(channelnotifier.ClosedChannelEvent) + if !ok { + continue + } + + chanID := lnwire.NewChanIDFromOutPoint( + &event.CloseSummary.ChanPoint, + ) + + log.Debugf("Received ClosedChannelEvent for "+ + "channel: %s", chanID) + + err := m.handleClosedChannel( + chanID, event.CloseSummary.CloseHeight, + ) + if err != nil { + log.Errorf("Could not handle channel close "+ + "event for channel(%s): %v", chanID, + err) + } + + case <-m.quit: + return + } + } +} + +// handleClosedChannel handles the closure of a single channel. It will mark the +// channel as closed in the DB, then it will handle all the sessions that are +// now closable due to the channel closure. +func (m *Manager) handleClosedChannel(chanID lnwire.ChannelID, + closeHeight uint32) error { + + m.backupMu.Lock() + defer m.backupMu.Unlock() + + // We only care about channels registered with the tower client. + if _, ok := m.chanInfos[chanID]; !ok { + return nil + } + + log.Debugf("Marking channel(%s) as closed", chanID) + + sessions, err := m.cfg.DB.MarkChannelClosed(chanID, closeHeight) + if err != nil { + return fmt.Errorf("could not mark channel(%s) as closed: %w", + chanID, err) + } + + closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) + for _, sess := range sessions { + closableSessions[sess] = closeHeight + } + + log.Debugf("Tracking %d new closable sessions as a result of "+ + "closing channel %s", len(closableSessions), chanID) + + err = m.trackClosableSessions(closableSessions) + if err != nil { + return fmt.Errorf("could not track closable sessions: %w", err) + } + + delete(m.chanInfos, chanID) + + return nil +} + +// handleClosableSessions listens for new block notifications. 
For each block,
+// it checks the closableSessionQueue to see if there is a closable session
+// with a delete-height less than or equal to the height of the new block. If
+// there is, the tower is informed that it can delete the session, and we then
+// also delete it from our DB.
+func (m *Manager) handleClosableSessions(
+	blocksChan *chainntnfs.BlockEpochEvent) {
+
+	defer m.wg.Done()
+
+	log.Debug("Starting closable sessions handler")
+	defer log.Debug("Stopping closable sessions handler")
+
+	for {
+		select {
+		case newBlock := <-blocksChan.Epochs:
+			if newBlock == nil {
+				return
+			}
+
+			height := uint32(newBlock.Height)
+			for {
+				select {
+				case <-m.quit:
+					return
+				default:
+				}
+
+				// If there are no closable sessions that we
+				// need to handle, then we are done and can
+				// reevaluate when the next block comes.
+				item := m.closableSessionQueue.Top()
+				if item == nil {
+					break
+				}
+
+				// If there is a closable session but the delete
+				// height we have set for it is after the
+				// current block height, then our work is done.
+				if item.deleteHeight > height {
+					break
+				}
+
+				// Otherwise, we pop this item from the heap
+				// and handle it.
+				m.closableSessionQueue.Pop()
+
+				// Fetch the session from the DB so that we can
+				// extract the Tower info.
+				sess, err := m.cfg.DB.GetClientSession(
+					item.sessionID,
+				)
+				if err != nil {
+					log.Errorf("error calling "+
+						"GetClientSession for "+
+						"session %s: %v",
+						item.sessionID, err)
+
+					continue
+				}
+
+				// Get the client registered for this session's
+				// blob type.
+				m.clientsMu.Lock()
+				client, ok := m.clients[sess.Policy.BlobType]
+				if !ok {
+					m.clientsMu.Unlock()
+					log.Errorf("no client currently " +
+						"active for the session type")
+
+					continue
+				}
+				m.clientsMu.Unlock()
+
+				clientName, err := client.policy().BlobType.
+					Identifier()
+				if err != nil {
+					log.Errorf("could not get client "+
+						"identifier: %v", err)
+
+					continue
+				}
+
+				// Stop the session and remove it from the
+				// in-memory set.
+				err = client.stopAndRemoveSession(
+					item.sessionID,
+				)
+				if err != nil {
+					log.Errorf("could not remove "+
+						"session(%s) from in-memory "+
+						"set of the %s client: %v",
+						item.sessionID, clientName, err)
+
+					continue
+				}
+
+				err = client.deleteSessionFromTower(sess)
+				if err != nil {
+					log.Errorf("error deleting "+
+						"session %s from tower: %v",
+						sess.ID, err)
+
+					continue
+				}
+
+				err = m.cfg.DB.DeleteSession(item.sessionID)
+				if err != nil {
+					log.Errorf("could not delete "+
+						"session(%s) from DB: %v",
+						sess.ID, err)
+
+					continue
+				}
+			}
+
+		case <-m.quit:
+			return
+		}
+	}
+}
+
+// getSweepScript returns the sweep pkscript registered for the given channel,
+// along with a bool indicating whether the channel is known to the Manager.
+func (m *Manager) getSweepScript(id lnwire.ChannelID) ([]byte, bool) {
+	m.backupMu.Lock()
+	defer m.backupMu.Unlock()
+
+	summary, ok := m.chanInfos[id]
+	if !ok {
+		return nil, false
+	}
+
+	return summary.SweepPkScript, true
+}
diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go
index 27c36c6fe..1bd7a4ddb 100644
--- a/watchtower/wtclient/session_queue.go
+++ b/watchtower/wtclient/session_queue.go
@@ -94,7 +94,7 @@ type sessionQueueConfig struct {
 // sessionQueue implements a reliable queue that will encrypt and send accepted
 // backups to the watchtower specified in the config's ClientSession. Calling
 // Stop will attempt to perform a clean shutdown replaying any un-committed
-// pending updates to the TowerClient's main task pipeline.
+// pending updates to the client's main task pipeline.
 type sessionQueue struct {
 	started sync.Once
 	stopped sync.Once
diff --git a/watchtower/wtclient/stats.go b/watchtower/wtclient/stats.go
index b41783ff3..35303e925 100644
--- a/watchtower/wtclient/stats.go
+++ b/watchtower/wtclient/stats.go
@@ -8,8 +8,6 @@ import (
 // ClientStats is a collection of in-memory statistics of the actions the client
 // has performed since its creation.
 type ClientStats struct {
-	mu sync.Mutex
-
 	// NumTasksPending is the total number of backups that are pending to
 	// be acknowledged by all active and exhausted watchtower sessions.
 	NumTasksPending int
@@ -31,68 +29,77 @@ type ClientStats struct {
 	NumSessionsExhausted int
 }
 
+// clientStats wraps ClientStats with a mutex so that its members can be
+// accessed in a thread-safe manner.
+type clientStats struct {
+	mu sync.Mutex
+
+	ClientStats
+}
+
 // taskReceived increments the number of backup requests the client has received
 // from active channels.
-func (s *ClientStats) taskReceived() {
+func (s *clientStats) taskReceived() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
+
 	s.NumTasksPending++
 }
 
 // taskAccepted increments the number of tasks that have been assigned to active
 // session queues, and are awaiting upload to a tower.
-func (s *ClientStats) taskAccepted() {
+func (s *clientStats) taskAccepted() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
+
 	s.NumTasksAccepted++
 	s.NumTasksPending--
 }
 
+// getStatsCopy returns a copy of the ClientStats.
+func (s *clientStats) getStatsCopy() ClientStats {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.ClientStats
+}
+
 // taskIneligible increments the number of tasks that were unable to satisfy the
 // active session queue's policy. These can potentially be retried later, but
 // typically this means that the balance created dust outputs, so it may not be
 // worth backing up at all.
-func (s *ClientStats) taskIneligible() {
+func (s *clientStats) taskIneligible() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
+
 	s.NumTasksIneligible++
 }
 
 // sessionAcquired increments the number of sessions that have been successfully
 // negotiated by the client during this execution.
-func (s *ClientStats) sessionAcquired() {
+func (s *clientStats) sessionAcquired() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
+
 	s.NumSessionsAcquired++
 }
 
 // sessionExhausted increments the number of session that have become full as a
 // result of accepting backup tasks.
-func (s *ClientStats) sessionExhausted() {
+func (s *clientStats) sessionExhausted() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
+
 	s.NumSessionsExhausted++
 }
 
 // String returns a human-readable summary of the client's metrics.
-func (s *ClientStats) String() string {
+func (s *clientStats) String() string {
 	s.mu.Lock()
 	defer s.mu.Unlock()
+
 	return fmt.Sprintf("tasks(received=%d accepted=%d ineligible=%d) "+
 		"sessions(acquired=%d exhausted=%d)",
 		s.NumTasksPending, s.NumTasksAccepted, s.NumTasksIneligible,
 		s.NumSessionsAcquired, s.NumSessionsExhausted)
 }
-
-// Copy returns a copy of the current stats.
-func (s *ClientStats) Copy() ClientStats {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	return ClientStats{
-		NumTasksPending:      s.NumTasksPending,
-		NumTasksAccepted:     s.NumTasksAccepted,
-		NumTasksIneligible:   s.NumTasksIneligible,
-		NumSessionsAcquired:  s.NumSessionsAcquired,
-		NumSessionsExhausted: s.NumSessionsExhausted,
-	}
-}
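
For reference, here is a minimal, illustrative sketch (not part of the patch) of how a caller might drive the new `wtclient.Manager` API introduced above. It assumes a fully populated `wtclient.Config` and two pre-built policies whose blob types are `blob.TypeAltruistCommit` and `blob.TypeAltruistAnchorCommit`; the helper and variable names are hypothetical.

```go
package wtclientsketch

import (
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/wtclient"
	"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
)

// runTowerClients is a hypothetical helper showing the Manager lifecycle:
// construct one Manager with a policy per commitment type, add a tower once,
// register channels with their channel type, and back up revoked states.
func runTowerClients(cfg *wtclient.Config, legacyPolicy,
	anchorPolicy wtpolicy.Policy, towerAddr *lnwire.NetAddress,
	chanID lnwire.ChannelID, chanType channeldb.ChannelType) error {

	// One Manager multiplexes a tower client per blob/commitment type.
	mgr, err := wtclient.NewManager(cfg, legacyPolicy, anchorPolicy)
	if err != nil {
		return err
	}
	if err := mgr.Start(); err != nil {
		return err
	}
	defer func() { _ = mgr.Stop() }()

	// Towers are added once and shared by all underlying clients.
	if err := mgr.AddTower(towerAddr); err != nil {
		return err
	}

	// The channel type decides which client (legacy or anchor) will back
	// up this channel.
	if err := mgr.RegisterChannel(chanID, chanType); err != nil {
		return err
	}

	// Queue revoked state #1 for backup; the Manager routes it to the
	// client registered for the channel's blob type.
	return mgr.BackupState(&chanID, 1)
}
```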