diff --git a/keychain/derivation.go b/keychain/derivation.go index d908da757..54144efca 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -90,6 +90,12 @@ const ( // a payment, or self stored on disk in a single file containing all // the static channel backups. KeyFamilyStaticBackup KeyFamily = 7 + + // KeyFamilyTowerSession is the family of keys that will be used to + // derive session keys when negotiating sessions with watchtowers. The + // session keys are limited to the lifetime of the session and are used + // to increase privacy in the watchtower protocol. + KeyFamilyTowerSession KeyFamily = 8 ) // KeyLocator is a two-tuple that can be used to derive *any* key that has ever diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 9bd8e9e8a..8ecb8c6db 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -9,7 +9,6 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtdb" @@ -77,7 +76,7 @@ type Config struct { // SecretKeyRing is used to derive the session keys used to communicate // with the tower. The client only stores the KeyLocators internally so // that we never store private keys on disk. - SecretKeyRing keychain.SecretKeyRing + SecretKeyRing SecretKeyRing // Dial connects to an addr using the specified net and returns the // connection object. @@ -201,15 +200,16 @@ func New(config *Config) (*TowerClient, error) { forceQuit: make(chan struct{}), } c.negotiator = newSessionNegotiator(&NegotiatorConfig{ - DB: cfg.DB, - Policy: cfg.Policy, - ChainHash: cfg.ChainHash, - SendMessage: c.sendMessage, - ReadMessage: c.readMessage, - Dial: c.dial, - Candidates: newTowerListIterator(tower), - MinBackoff: cfg.MinBackoff, - MaxBackoff: cfg.MaxBackoff, + DB: cfg.DB, + SecretKeyRing: cfg.SecretKeyRing, + Policy: cfg.Policy, + ChainHash: cfg.ChainHash, + SendMessage: c.sendMessage, + ReadMessage: c.readMessage, + Dial: c.dial, + Candidates: newTowerListIterator(tower), + MinBackoff: cfg.MinBackoff, + MaxBackoff: cfg.MaxBackoff, }) // Next, load all active sessions from the db into the client. We will @@ -221,6 +221,28 @@ func New(config *Config) (*TowerClient, error) { return nil, err } + // Reload any towers from disk using the tower IDs contained in each + // candidate session. We will also rederive any session keys needed to + // be able to communicate with the towers and authenticate session + // requests. This prevents us from having to store the private keys on + // disk. + for _, s := range c.candidateSessions { + tower, err := c.cfg.DB.LoadTower(s.TowerID) + if err != nil { + return nil, err + } + + sessionPriv, err := DeriveSessionKey( + c.cfg.SecretKeyRing, s.KeyIndex, + ) + if err != nil { + return nil, err + } + + s.Tower = tower + s.SessionPrivKey = sessionPriv + } + // Finally, load the sweep pkscripts that have been generated for all // previously registered channels. c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts() @@ -334,9 +356,6 @@ func (c *TowerClient) ForceQuit() { c.forced.Do(func() { log.Infof("Force quitting watchtower client") - // Cancel log message from stop. - close(c.forceQuit) - // 1. Shutdown the backup queue, which will prevent any further // updates from being accepted. In practice, the links should be // shutdown before the client has been stopped, so all updates @@ -347,6 +366,7 @@ func (c *TowerClient) ForceQuit() { // dispatcher to exit. The backup queue will signal it's // completion to the dispatcher, which releases the wait group // after all tasks have been assigned to session queues. + close(c.forceQuit) c.wg.Wait() // 3. Since all valid tasks have been assigned to session @@ -490,6 +510,9 @@ func (c *TowerClient) backupDispatcher() { case <-c.statTicker.C: log.Infof("Client stats: %s", c.stats) + + case <-c.forceQuit: + return } // No active session queue but have additional sessions. diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index dba4275e6..68c34e0e1 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -379,10 +379,11 @@ type testHarness struct { } type harnessCfg struct { - localBalance lnwire.MilliSatoshi - remoteBalance lnwire.MilliSatoshi - policy wtpolicy.Policy - noRegisterChan0 bool + localBalance lnwire.MilliSatoshi + remoteBalance lnwire.MilliSatoshi + policy wtpolicy.Policy + noRegisterChan0 bool + noAckCreateSession bool } func newHarness(t *testing.T, cfg harnessCfg) *testHarness { @@ -414,6 +415,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { NewAddress: func() (btcutil.Address, error) { return addr, nil }, + NoAckCreateSession: cfg.noAckCreateSession, } server, err := wtserver.New(serverCfg) @@ -430,10 +432,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { Dial: func(string, string) (net.Conn, error) { return nil, nil }, - DB: clientDB, - AuthDial: mockNet.AuthDial, - PrivateTower: towerAddr, - Policy: cfg.policy, + DB: clientDB, + AuthDial: mockNet.AuthDial, + SecretKeyRing: wtmock.NewSecretKeyRing(), + PrivateTower: towerAddr, + Policy: cfg.policy, NewAddress: func() ([]byte, error) { return addrScript, nil }, @@ -729,6 +732,36 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint, } } +// assertUpdatesForPolicy queries the server db for matches using the provided +// breach hints, then asserts that each match has a session with the expected +// policy. +func (h *testHarness) assertUpdatesForPolicy(hints []wtdb.BreachHint, + expPolicy wtpolicy.Policy) { + + // Query for matches on the provided hints. + matches, err := h.serverDB.QueryMatches(hints) + if err != nil { + h.t.Fatalf("unable to query for matches: %v", err) + } + + // Assert that the number of matches is exactly the number of provided + // hints. + if len(matches) != len(hints) { + h.t.Fatalf("expected: %d matches, got: %d", len(hints), + len(matches)) + } + + // Assert that all of the matches correspond to a session with the + // expected policy. + for _, match := range matches { + matchPolicy := match.SessionInfo.Policy + if expPolicy != matchPolicy { + h.t.Fatalf("expected session to have policy: %v, "+ + "got: %v", expPolicy, matchPolicy) + } + } +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) @@ -1098,6 +1131,119 @@ var clientTests = []clientTest{ h.waitServerUpdates(hints, 10*time.Second) }, }, + { + name: "create session no ack", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 5, + SweepFeeRate: 1, + }, + noAckCreateSession: true, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 3 + ) + + // Generate the retributions that will be backed up. + hints := h.advanceChannelN(chanID, numUpdates) + + // Now, queue the retributions for backup. + h.backupStates(chanID, 0, numUpdates, nil) + + // Since the client is unable to create a session, the + // server should have no updates. + h.waitServerUpdates(nil, time.Second) + + // Force quit the client since it has queued backups. + h.client.ForceQuit() + + // Restart the server and allow it to ack session + // creation. + h.server.Stop() + h.serverCfg.NoAckCreateSession = false + h.startServer() + defer h.server.Stop() + + // Restart the client with the same policy, which will + // immediately try to overwrite the old session with an + // identical one. + h.startClient() + defer h.client.ForceQuit() + + // Now, queue the retributions for backup. + h.backupStates(chanID, 0, numUpdates, nil) + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 5*time.Second) + + // Assert that the server has updates for the clients + // most recent policy. + h.assertUpdatesForPolicy(hints, h.clientCfg.Policy) + }, + }, + { + name: "create session no ack change policy", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + BlobType: blob.TypeDefault, + MaxUpdates: 5, + SweepFeeRate: 1, + }, + noAckCreateSession: true, + }, + fn: func(h *testHarness) { + const ( + chanID = 0 + numUpdates = 3 + ) + + // Generate the retributions that will be backed up. + hints := h.advanceChannelN(chanID, numUpdates) + + // Now, queue the retributions for backup. + h.backupStates(chanID, 0, numUpdates, nil) + + // Since the client is unable to create a session, the + // server should have no updates. + h.waitServerUpdates(nil, time.Second) + + // Force quit the client since it has queued backups. + h.client.ForceQuit() + + // Restart the server and allow it to ack session + // creation. + h.server.Stop() + h.serverCfg.NoAckCreateSession = false + h.startServer() + defer h.server.Stop() + + // Restart the client with a new policy, which will + // immediately try to overwrite the prior session with + // the old policy. + h.clientCfg.Policy.SweepFeeRate = 2 + h.startClient() + defer h.client.ForceQuit() + + // Now, queue the retributions for backup. + h.backupStates(chanID, 0, numUpdates, nil) + + // Wait for all of the updates to be populated in the + // server's database. + h.waitServerUpdates(hints, 5*time.Second) + + // Assert that the server has updates for the clients + // most recent policy. + h.assertUpdatesForPolicy(hints, h.clientCfg.Policy) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup diff --git a/watchtower/wtclient/derivation.go b/watchtower/wtclient/derivation.go new file mode 100644 index 000000000..cb0caec00 --- /dev/null +++ b/watchtower/wtclient/derivation.go @@ -0,0 +1,24 @@ +package wtclient + +import ( + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/keychain" +) + +// DeriveSessionKey accepts an session key index for an existing session and +// derives the HD private key to be used to authenticate the brontide transport +// and authenticate requests sent to the tower. The key will use the +// keychain.KeyFamilyTowerSession and the provided index, giving a BIP43 +// derivation path of: +// +// * m/1017'/coinType'/8/0/index +func DeriveSessionKey(keyRing SecretKeyRing, + index uint32) (*btcec.PrivateKey, error) { + + return keyRing.DerivePrivKey(keychain.KeyDescriptor{ + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: index, + }, + }) +} diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 5164acea8..5aef86196 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtserver" @@ -19,6 +20,17 @@ type DB interface { // sessions. CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error) + // LoadTower retrieves a tower by its tower ID. + LoadTower(uint64) (*wtdb.Tower, error) + + // NextSessionKeyIndex reserves a new session key derivation index for a + // particular tower id. The index is reserved for that tower until + // CreateClientSession is invoked for that tower and index, at which + // point a new index for that tower can be reserved. Multiple calls to + // this method before CreateClientSession is invoked should return the + // same index. + NextSessionKeyIndex(uint64) (uint32, error) + // CreateClientSession saves a newly negotiated client session to the // client's database. This enables the session to be used across // restarts. @@ -74,3 +86,11 @@ func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, return brontide.Dial(localPriv, netAddr, dialer) } + +// SecretKeyRing abstracts the ability to derive HD private keys given a +// description of the derivation path. +type SecretKeyRing interface { + // DerivePrivKey derives the private key from the root seed using a + // key descriptor specifying the key's derivation path. + DerivePrivKey(loc keychain.KeyDescriptor) (*btcec.PrivateKey, error) +} diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index b62819cb3..e16296fa8 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -42,6 +42,10 @@ type NegotiatorConfig struct { // negotiated sessions. DB DB + // SecretKeyRing allows the client to derive new session private keys + // when attempting to negotiate session with a tower. + SecretKeyRing SecretKeyRing + // Candidates is an abstract set of tower candidates that the negotiator // will traverse serially when attempting to negotiate a new session. Candidates TowerCandidateIterator @@ -224,7 +228,7 @@ func (n *sessionNegotiator) negotiate() { // On the first pass, initialize the backoff to our configured min // backoff. - backoff := n.cfg.MinBackoff + var backoff time.Duration retryWithBackoff: // If we are retrying, wait out the delay before continuing. @@ -240,13 +244,24 @@ retryWithBackoff: // iterator to ensure the results are fresh. n.cfg.Candidates.Reset() for { + select { + case <-n.quit: + return + default: + } + // Pull the next candidate from our list of addresses. tower, err := n.cfg.Candidates.Next() if err != nil { - // We've run out of addresses, double and clamp backoff. - backoff *= 2 - if backoff > n.cfg.MaxBackoff { - backoff = n.cfg.MaxBackoff + if backoff == 0 { + backoff = n.cfg.MinBackoff + } else { + // We've run out of addresses, double and clamp + // backoff. + backoff *= 2 + if backoff > n.cfg.MaxBackoff { + backoff = n.cfg.MaxBackoff + } } log.Debugf("Unable to get new tower candidate, "+ @@ -255,12 +270,23 @@ retryWithBackoff: goto retryWithBackoff } + towerPub := tower.IdentityKey.SerializeCompressed() log.Debugf("Attempting session negotiation with tower=%x", - tower.IdentityKey.SerializeCompressed()) + towerPub) + + // Before proceeding, we will reserve a session key index to use + // with this specific tower. If one is already reserved, the + // existing index will be returned. + keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID) + if err != nil { + log.Debugf("Unable to reserve session key index "+ + "for tower=%x: %v", towerPub, err) + continue + } // We'll now attempt the CreateSession dance with the tower to // get a new session, trying all addresses if necessary. - err = n.createSession(tower) + err = n.createSession(tower, keyIndex) if err != nil { log.Debugf("Session negotiation with tower=%x "+ "failed, trying again -- reason: %v", @@ -277,22 +303,21 @@ retryWithBackoff: // its stored addresses. This method returns after the first successful // negotiation, or after all addresses have failed with ErrFailedNegotiation. If // the tower has no addresses, ErrNoTowerAddrs is returned. -func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error { +func (n *sessionNegotiator) createSession(tower *wtdb.Tower, + keyIndex uint32) error { + // If the tower has no addresses, there's nothing we can do. if len(tower.Addresses) == 0 { return ErrNoTowerAddrs } - // TODO(conner): create with hdkey at random index - sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256()) + sessionPriv, err := DeriveSessionKey(n.cfg.SecretKeyRing, keyIndex) if err != nil { return err } - // TODO(conner): write towerAddr+privkey - for _, lnAddr := range tower.LNAddrs() { - err = n.tryAddress(sessionPrivKey, tower, lnAddr) + err = n.tryAddress(sessionPriv, keyIndex, tower, lnAddr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset @@ -318,7 +343,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error { // returns true if all steps succeed and the new session has been persisted, and // fails otherwise. func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey, - tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { + keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { // Connect to the tower address using our generated session key. conn, err := n.cfg.Dial(privKey, lnAddr) @@ -394,7 +419,8 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey, clientSession := &wtdb.ClientSession{ TowerID: tower.ID, Tower: tower, - SessionPrivKey: privKey, // remove after using HD keys + KeyIndex: keyIndex, + SessionPrivKey: privKey, ID: sessionID, Policy: n.cfg.Policy, SeqNum: 0, diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index f7b531fec..583472e66 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/btcsuite/btcd/btcec" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) @@ -30,6 +29,15 @@ var ( // LastApplied value greater than any allocated sequence number. ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " + "greater than allocated seqnum") + + // ErrNoReservedKeyIndex signals that a client session could not be + // created because no session key index was reserved. + ErrNoReservedKeyIndex = errors.New("key index not reserved") + + // ErrIncorrectKeyIndex signals that the client session could not be + // created because session key index differs from the reserved key + // index. + ErrIncorrectKeyIndex = errors.New("incorrect key index") ) // ClientSession encapsulates a SessionInfo returned from a successful @@ -57,14 +65,17 @@ type ClientSession struct { // tower with TowerID. Tower *Tower - // SessionKeyDesc is the key descriptor used to derive the client's + // KeyIndex is the index of key locator used to derive the client's // session key so that it can authenticate with the tower to update its - // session. - SessionKeyDesc keychain.KeyLocator + // session. In order to rederive the private key, the key locator should + // use the keychain.KeyFamilyTowerSession key family. + KeyIndex uint32 // SessionPrivKey is the ephemeral secret key used to connect to the // watchtower. - // TODO(conner): remove after HD keys + // + // NOTE: This value is not serialized. It is derived using the KeyIndex + // on startup to avoid storing private keys on disk. SessionPrivKey *btcec.PrivateKey // Policy holds the negotiated session parameters. diff --git a/watchtower/wtdb/mock.go b/watchtower/wtdb/mock.go index fa53d6cf7..902303881 100644 --- a/watchtower/wtdb/mock.go +++ b/watchtower/wtdb/mock.go @@ -61,7 +61,8 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error { db.mu.Lock() defer db.mu.Unlock() - if _, ok := db.sessions[info.ID]; ok { + dbInfo, ok := db.sessions[info.ID] + if ok && dbInfo.LastApplied > 0 { return ErrSessionAlreadyExists } diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index e7213cab4..ff7a48df2 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -1,6 +1,7 @@ package wtdb import ( + "errors" "net" "sync" @@ -8,6 +9,12 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +var ( + // ErrTowerNotFound signals that the target tower was not found in the + // database. + ErrTowerNotFound = errors.New("tower not found") +) + // Tower holds the necessary components required to connect to a remote tower. // Communication is handled by brontide, and requires both a public key and an // address. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 54f9a697e..a075e7d9f 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -22,6 +22,9 @@ type ClientDB struct { activeSessions map[wtdb.SessionID]*wtdb.ClientSession towerIndex map[towerPK]uint64 towers map[uint64]*wtdb.Tower + + nextIndex uint32 + indexes map[uint64]uint32 } // NewClientDB initializes a new mock ClientDB. @@ -31,6 +34,7 @@ func NewClientDB() *ClientDB { activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession), towerIndex: make(map[towerPK]uint64), towers: make(map[uint64]*wtdb.Tower), + indexes: make(map[uint64]uint32), } } @@ -64,6 +68,18 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { return tower, nil } +// LoadTower retrieves a tower by its tower ID. +func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if tower, ok := m.towers[towerID]; ok { + return tower, nil + } + + return nil, wtdb.ErrTowerNotFound +} + // MarkBackupIneligible records that particular commit height is ineligible for // backup. This allows the client to track which updates it should not attempt // to retry after startup. @@ -90,16 +106,29 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { m.mu.Lock() defer m.mu.Unlock() + // Ensure that a session key index has been reserved for this tower. + keyIndex, ok := m.indexes[session.TowerID] + if !ok { + return wtdb.ErrNoReservedKeyIndex + } + + // Ensure that the session's index matches the reserved index. + if keyIndex != session.KeyIndex { + return wtdb.ErrIncorrectKeyIndex + } + + // Remove the key index reservation for this tower. Once committed, this + // permits us to create another session with this tower. + delete(m.indexes, session.TowerID) + m.activeSessions[session.ID] = &wtdb.ClientSession{ TowerID: session.TowerID, - Tower: session.Tower, - SessionKeyDesc: session.SessionKeyDesc, - SessionPrivKey: session.SessionPrivKey, + KeyIndex: session.KeyIndex, ID: session.ID, Policy: session.Policy, SeqNum: session.SeqNum, TowerLastApplied: session.TowerLastApplied, - RewardPkScript: session.RewardPkScript, + RewardPkScript: cloneBytes(session.RewardPkScript), CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate), AckedUpdates: make(map[uint16]wtdb.BackupID), } @@ -107,6 +136,27 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { return nil } +// NextSessionKeyIndex reserves a new session key derivation index for a +// particular tower id. The index is reserved for that tower until +// CreateClientSession is invoked for that tower and index, at which point a new +// index for that tower can be reserved. Multiple calls to this method before +// CreateClientSession is invoked should return the same index. +func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if index, ok := m.indexes[towerID]; ok { + return index, nil + } + + index := m.nextIndex + m.indexes[towerID] = index + + m.nextIndex++ + + return index, nil +} + // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16, @@ -217,7 +267,12 @@ func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) err } func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + bb := make([]byte, len(b)) copy(bb, b) + return bb } diff --git a/watchtower/wtmock/keyring.go b/watchtower/wtmock/keyring.go new file mode 100644 index 000000000..f18a0fa8e --- /dev/null +++ b/watchtower/wtmock/keyring.go @@ -0,0 +1,44 @@ +package wtmock + +import ( + "sync" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/keychain" +) + +// SecretKeyRing is a mock, in-memory implementation for deriving private keys. +type SecretKeyRing struct { + mu sync.Mutex + keys map[keychain.KeyLocator]*btcec.PrivateKey +} + +// NewSecretKeyRing creates a new mock SecretKeyRing. +func NewSecretKeyRing() *SecretKeyRing { + return &SecretKeyRing{ + keys: make(map[keychain.KeyLocator]*btcec.PrivateKey), + } +} + +// DerivePrivKey derives the private key for a given key descriptor. If +// this method is called twice with the same argument, it will return the same +// private key. +func (m *SecretKeyRing) DerivePrivKey( + desc keychain.KeyDescriptor) (*btcec.PrivateKey, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + if key, ok := m.keys[desc.KeyLocator]; ok { + return key, nil + } + + privKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, err + } + + m.keys[desc.KeyLocator] = privKey + + return privKey, nil +} diff --git a/watchtower/wtserver/create_session.go b/watchtower/wtserver/create_session.go index e948d8b9a..411742e0a 100644 --- a/watchtower/wtserver/create_session.go +++ b/watchtower/wtserver/create_session.go @@ -21,45 +21,26 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, existingInfo, err := s.cfg.DB.GetSessionInfo(id) switch { + // We already have a session, though it is currently unused. We'll allow + // the client to recommit the session if it wanted to change the policy. + case err == nil && existingInfo.LastApplied == 0: + // We already have a session corresponding to this session id, return an // error signaling that it already exists in our database. We return the // reward address to the client in case they were not able to process // our reply earlier. - case err == nil: + case err == nil && existingInfo.LastApplied > 0: log.Debugf("Already have session for %s", id) return s.replyCreateSession( peer, id, wtwire.CreateSessionCodeAlreadyExists, - existingInfo.RewardAddress, + existingInfo.LastApplied, existingInfo.RewardAddress, ) // Some other database error occurred, return a temporary failure. case err != wtdb.ErrSessionNotFound: log.Errorf("unable to load session info for %s", id) return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, - ) - } - - // Now that we've established that this session does not exist in the - // database, retrieve the sweep address that will be given to the - // client. This address is to be included by the client when signing - // sweep transactions destined for this tower, if its negotiated output - // is not dust. - rewardAddress, err := s.cfg.NewAddress() - if err != nil { - log.Errorf("unable to generate reward addr for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, - ) - } - - // Construct the pkscript the client should pay to when signing justice - // transactions for this session. - rewardScript, err := txscript.PayToAddrScript(rewardAddress) - if err != nil { - log.Errorf("unable to generate reward script for %s", id) - return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, + peer, id, wtwire.CodeTemporaryFailure, 0, nil, ) } @@ -68,10 +49,39 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, log.Debugf("Rejecting CreateSession from %s, unsupported blob "+ "type %s", id, req.BlobType) return s.replyCreateSession( - peer, id, wtwire.CreateSessionCodeRejectBlobType, nil, + peer, id, wtwire.CreateSessionCodeRejectBlobType, 0, + nil, ) } + // Now that we've established that this session does not exist in the + // database, retrieve the sweep address that will be given to the + // client. This address is to be included by the client when signing + // sweep transactions destined for this tower, if its negotiated output + // is not dust. + var rewardScript []byte + if req.BlobType.Has(blob.FlagReward) { + rewardAddress, err := s.cfg.NewAddress() + if err != nil { + log.Errorf("Unable to generate reward addr for %s: %v", + id, err) + return s.replyCreateSession( + peer, id, wtwire.CodeTemporaryFailure, 0, nil, + ) + } + + // Construct the pkscript the client should pay to when signing + // justice transactions for this session. + rewardScript, err = txscript.PayToAddrScript(rewardAddress) + if err != nil { + log.Errorf("Unable to generate reward script for "+ + "%s: %v", id, err) + return s.replyCreateSession( + peer, id, wtwire.CodeTemporaryFailure, 0, nil, + ) + } + } + // TODO(conner): create invoice for upfront payment // Assemble the session info using the agreed upon parameters, reward @@ -94,14 +104,14 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, if err != nil { log.Errorf("unable to create session for %s", id) return s.replyCreateSession( - peer, id, wtwire.CodeTemporaryFailure, nil, + peer, id, wtwire.CodeTemporaryFailure, 0, nil, ) } log.Infof("Accepted session for %s", id) return s.replyCreateSession( - peer, id, wtwire.CodeOK, rewardScript, + peer, id, wtwire.CodeOK, 0, rewardScript, ) } @@ -110,11 +120,19 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, // Otherwise, this method returns a connection error to ensure we don't continue // communication with the client. func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID, - code wtwire.ErrorCode, data []byte) error { + code wtwire.ErrorCode, lastApplied uint16, data []byte) error { + + if s.cfg.NoAckCreateSession { + return &connFailure{ + ID: *id, + Code: code, + } + } msg := &wtwire.CreateSessionReply{ - Code: code, - Data: data, + Code: code, + LastApplied: lastApplied, + Data: data, } err := s.sendMessage(peer, msg) @@ -131,6 +149,6 @@ func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID, // disconnect the client. return &connFailure{ ID: *id, - Code: uint16(code), + Code: code, } } diff --git a/watchtower/wtserver/delete_session.go b/watchtower/wtserver/delete_session.go index a5b517c12..00b9b1658 100644 --- a/watchtower/wtserver/delete_session.go +++ b/watchtower/wtserver/delete_session.go @@ -52,6 +52,6 @@ func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID, // disconnect the client. return &connFailure{ ID: *id, - Code: uint16(code), + Code: code, } } diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index 4a49c3ad4..d4ee88741 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -56,6 +56,10 @@ type Config struct { // ChainHash identifies the network that the server is watching. ChainHash chainhash.Hash + // NoAckCreateSession causes the server to not reply to create session + // requests, this should only be used for testing. + NoAckCreateSession bool + // NoAckUpdates causes the server to not acknowledge state updates, this // should only be used for testing. NoAckUpdates bool @@ -283,12 +287,12 @@ func (s *Server) handleClient(peer Peer) { // error code. type connFailure struct { ID wtdb.SessionID - Code uint16 + Code wtwire.ErrorCode } // Error displays the SessionID and Code that caused the connection failure. func (f *connFailure) Error() string { - return fmt.Sprintf("connection with %s failed with code=%v", + return fmt.Sprintf("connection with %s failed with code=%s", f.ID, f.Code, ) } diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index d9e3bec50..6d99180fd 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -29,6 +29,8 @@ var ( addrScript, _ = txscript.PayToAddrScript(addr) testnetChainHash = *chaincfg.TestNet3Params.GenesisHash + + rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type() ) // randPubKey generates a new secp keypair, and returns the public key. @@ -152,16 +154,17 @@ func TestServerOnlyAcceptOnePeer(t *testing.T) { } type createSessionTestCase struct { - name string - initMsg *wtwire.Init - createMsg *wtwire.CreateSession - expReply *wtwire.CreateSessionReply - expDupReply *wtwire.CreateSessionReply + name string + initMsg *wtwire.Init + createMsg *wtwire.CreateSession + expReply *wtwire.CreateSessionReply + expDupReply *wtwire.CreateSessionReply + sendStateUpdate bool } var createSessionTests = []createSessionTestCase{ { - name: "reject duplicate session create", + name: "duplicate session create", initMsg: wtwire.NewInitMessage( lnwire.NewRawFeatureVector(), testnetChainHash, @@ -173,12 +176,58 @@ var createSessionTests = []createSessionTestCase{ RewardRate: 0, SweepFeeRate: 1, }, + expReply: &wtwire.CreateSessionReply{ + Code: wtwire.CodeOK, + Data: []byte{}, + }, + expDupReply: &wtwire.CreateSessionReply{ + Code: wtwire.CodeOK, + Data: []byte{}, + }, + }, + { + name: "duplicate session create after use", + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), + createMsg: &wtwire.CreateSession{ + BlobType: blob.TypeDefault, + MaxUpdates: 1000, + RewardBase: 0, + RewardRate: 0, + SweepFeeRate: 1, + }, + expReply: &wtwire.CreateSessionReply{ + Code: wtwire.CodeOK, + Data: []byte{}, + }, + expDupReply: &wtwire.CreateSessionReply{ + Code: wtwire.CreateSessionCodeAlreadyExists, + LastApplied: 1, + Data: []byte{}, + }, + sendStateUpdate: true, + }, + { + name: "duplicate session create reward", + initMsg: wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(), + testnetChainHash, + ), + createMsg: &wtwire.CreateSession{ + BlobType: rewardType, + MaxUpdates: 1000, + RewardBase: 0, + RewardRate: 0, + SweepFeeRate: 1, + }, expReply: &wtwire.CreateSessionReply{ Code: wtwire.CodeOK, Data: addrScript, }, expDupReply: &wtwire.CreateSessionReply{ - Code: wtwire.CreateSessionCodeAlreadyExists, + Code: wtwire.CodeOK, Data: addrScript, }, }, @@ -251,6 +300,18 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) { return } + if test.sendStateUpdate { + peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) + connect(t, s, peer, test.initMsg, timeoutDuration) + update := &wtwire.StateUpdate{ + SeqNum: 1, + IsComplete: 1, + } + sendMsg(t, update, peer, timeoutDuration) + + assertConnClosed(t, peer, 2*timeoutDuration) + } + // Simulate a peer with the same session id connection to the server // again. peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0) @@ -705,7 +766,7 @@ func TestServerDeleteSession(t *testing.T) { send: createSession, recv: &wtwire.CreateSessionReply{ Code: wtwire.CodeOK, - Data: addrScript, + Data: []byte{}, }, assert: func(t *testing.T) { // Both peers should have sessions. diff --git a/watchtower/wtserver/state_update.go b/watchtower/wtserver/state_update.go index 7b3e0941b..63a0a37d6 100644 --- a/watchtower/wtserver/state_update.go +++ b/watchtower/wtserver/state_update.go @@ -117,7 +117,7 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID, if s.cfg.NoAckUpdates { return &connFailure{ ID: *id, - Code: uint16(failCode), + Code: failCode, } } @@ -152,6 +152,6 @@ func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID, // disconnect the client. return &connFailure{ ID: *id, - Code: uint16(code), + Code: code, } } diff --git a/watchtower/wtwire/create_session_reply.go b/watchtower/wtwire/create_session_reply.go index 9b63b08c7..82d2b6ab4 100644 --- a/watchtower/wtwire/create_session_reply.go +++ b/watchtower/wtwire/create_session_reply.go @@ -43,6 +43,12 @@ type CreateSessionReply struct { // Code will be non-zero if the watchtower rejected the session init. Code CreateSessionCode + // LastApplied is the tower's last accepted sequence number for the + // session. This is useful when the session already exists but the + // client doesn't realize it's already used the session, such as after a + // restoration. + LastApplied uint16 + // Data is a byte slice returned the caller of the message, and is to be // interpreted according to the error Code. When the response is // CreateSessionCodeOK, data encodes the reward address to be included in @@ -63,6 +69,7 @@ var _ Message = (*CreateSessionReply)(nil) func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &m.Code, + &m.LastApplied, &m.Data, ) } @@ -74,6 +81,7 @@ func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error { func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error { return WriteElements(w, m.Code, + m.LastApplied, m.Data, ) } diff --git a/watchtower/wtwire/delete_session_reply.go b/watchtower/wtwire/delete_session_reply.go index 059d189e8..e25664a5d 100644 --- a/watchtower/wtwire/delete_session_reply.go +++ b/watchtower/wtwire/delete_session_reply.go @@ -12,8 +12,6 @@ const ( // client side, or that the tower had already deleted the session in a // prior request that the client may not have received. DeleteSessionCodeNotFound DeleteSessionCode = 80 - - // TODO(conner): add String method after wtclient is merged ) // DeleteSessionReply is a message sent in response to a client's DeleteSession diff --git a/watchtower/wtwire/error_code.go b/watchtower/wtwire/error_code.go index 2f4bc6bb3..c614f2a6e 100644 --- a/watchtower/wtwire/error_code.go +++ b/watchtower/wtwire/error_code.go @@ -46,6 +46,8 @@ func (c ErrorCode) String() string { return "StateUpdateCodeMaxUpdatesExceeded" case StateUpdateCodeSeqNumOutOfOrder: return "StateUpdateCodeSeqNumOutOfOrder" + case DeleteSessionCodeNotFound: + return "DeleteSessionCodeNotFound" default: return fmt.Sprintf("UnknownErrorCode: %d", c) }