mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-10 22:52:41 +02:00
Merge pull request #2820 from cfromknecht/session-key-derivation
wtclient: session private key derivation
This commit is contained in:
@@ -90,6 +90,12 @@ const (
|
|||||||
// a payment, or self stored on disk in a single file containing all
|
// a payment, or self stored on disk in a single file containing all
|
||||||
// the static channel backups.
|
// the static channel backups.
|
||||||
KeyFamilyStaticBackup KeyFamily = 7
|
KeyFamilyStaticBackup KeyFamily = 7
|
||||||
|
|
||||||
|
// KeyFamilyTowerSession is the family of keys that will be used to
|
||||||
|
// derive session keys when negotiating sessions with watchtowers. The
|
||||||
|
// session keys are limited to the lifetime of the session and are used
|
||||||
|
// to increase privacy in the watchtower protocol.
|
||||||
|
KeyFamilyTowerSession KeyFamily = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
|
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
|
||||||
|
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/lightningnetwork/lnd/input"
|
"github.com/lightningnetwork/lnd/input"
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwallet"
|
"github.com/lightningnetwork/lnd/lnwallet"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
@@ -77,7 +76,7 @@ type Config struct {
|
|||||||
// SecretKeyRing is used to derive the session keys used to communicate
|
// SecretKeyRing is used to derive the session keys used to communicate
|
||||||
// with the tower. The client only stores the KeyLocators internally so
|
// with the tower. The client only stores the KeyLocators internally so
|
||||||
// that we never store private keys on disk.
|
// that we never store private keys on disk.
|
||||||
SecretKeyRing keychain.SecretKeyRing
|
SecretKeyRing SecretKeyRing
|
||||||
|
|
||||||
// Dial connects to an addr using the specified net and returns the
|
// Dial connects to an addr using the specified net and returns the
|
||||||
// connection object.
|
// connection object.
|
||||||
@@ -202,6 +201,7 @@ func New(config *Config) (*TowerClient, error) {
|
|||||||
}
|
}
|
||||||
c.negotiator = newSessionNegotiator(&NegotiatorConfig{
|
c.negotiator = newSessionNegotiator(&NegotiatorConfig{
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
|
SecretKeyRing: cfg.SecretKeyRing,
|
||||||
Policy: cfg.Policy,
|
Policy: cfg.Policy,
|
||||||
ChainHash: cfg.ChainHash,
|
ChainHash: cfg.ChainHash,
|
||||||
SendMessage: c.sendMessage,
|
SendMessage: c.sendMessage,
|
||||||
@@ -221,6 +221,28 @@ func New(config *Config) (*TowerClient, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reload any towers from disk using the tower IDs contained in each
|
||||||
|
// candidate session. We will also rederive any session keys needed to
|
||||||
|
// be able to communicate with the towers and authenticate session
|
||||||
|
// requests. This prevents us from having to store the private keys on
|
||||||
|
// disk.
|
||||||
|
for _, s := range c.candidateSessions {
|
||||||
|
tower, err := c.cfg.DB.LoadTower(s.TowerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionPriv, err := DeriveSessionKey(
|
||||||
|
c.cfg.SecretKeyRing, s.KeyIndex,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Tower = tower
|
||||||
|
s.SessionPrivKey = sessionPriv
|
||||||
|
}
|
||||||
|
|
||||||
// Finally, load the sweep pkscripts that have been generated for all
|
// Finally, load the sweep pkscripts that have been generated for all
|
||||||
// previously registered channels.
|
// previously registered channels.
|
||||||
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
|
c.sweepPkScripts, err = c.cfg.DB.FetchChanPkScripts()
|
||||||
@@ -334,9 +356,6 @@ func (c *TowerClient) ForceQuit() {
|
|||||||
c.forced.Do(func() {
|
c.forced.Do(func() {
|
||||||
log.Infof("Force quitting watchtower client")
|
log.Infof("Force quitting watchtower client")
|
||||||
|
|
||||||
// Cancel log message from stop.
|
|
||||||
close(c.forceQuit)
|
|
||||||
|
|
||||||
// 1. Shutdown the backup queue, which will prevent any further
|
// 1. Shutdown the backup queue, which will prevent any further
|
||||||
// updates from being accepted. In practice, the links should be
|
// updates from being accepted. In practice, the links should be
|
||||||
// shutdown before the client has been stopped, so all updates
|
// shutdown before the client has been stopped, so all updates
|
||||||
@@ -347,6 +366,7 @@ func (c *TowerClient) ForceQuit() {
|
|||||||
// dispatcher to exit. The backup queue will signal it's
|
// dispatcher to exit. The backup queue will signal it's
|
||||||
// completion to the dispatcher, which releases the wait group
|
// completion to the dispatcher, which releases the wait group
|
||||||
// after all tasks have been assigned to session queues.
|
// after all tasks have been assigned to session queues.
|
||||||
|
close(c.forceQuit)
|
||||||
c.wg.Wait()
|
c.wg.Wait()
|
||||||
|
|
||||||
// 3. Since all valid tasks have been assigned to session
|
// 3. Since all valid tasks have been assigned to session
|
||||||
@@ -490,6 +510,9 @@ func (c *TowerClient) backupDispatcher() {
|
|||||||
|
|
||||||
case <-c.statTicker.C:
|
case <-c.statTicker.C:
|
||||||
log.Infof("Client stats: %s", c.stats)
|
log.Infof("Client stats: %s", c.stats)
|
||||||
|
|
||||||
|
case <-c.forceQuit:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// No active session queue but have additional sessions.
|
// No active session queue but have additional sessions.
|
||||||
|
@@ -383,6 +383,7 @@ type harnessCfg struct {
|
|||||||
remoteBalance lnwire.MilliSatoshi
|
remoteBalance lnwire.MilliSatoshi
|
||||||
policy wtpolicy.Policy
|
policy wtpolicy.Policy
|
||||||
noRegisterChan0 bool
|
noRegisterChan0 bool
|
||||||
|
noAckCreateSession bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
||||||
@@ -414,6 +415,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||||||
NewAddress: func() (btcutil.Address, error) {
|
NewAddress: func() (btcutil.Address, error) {
|
||||||
return addr, nil
|
return addr, nil
|
||||||
},
|
},
|
||||||
|
NoAckCreateSession: cfg.noAckCreateSession,
|
||||||
}
|
}
|
||||||
|
|
||||||
server, err := wtserver.New(serverCfg)
|
server, err := wtserver.New(serverCfg)
|
||||||
@@ -432,6 +434,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||||||
},
|
},
|
||||||
DB: clientDB,
|
DB: clientDB,
|
||||||
AuthDial: mockNet.AuthDial,
|
AuthDial: mockNet.AuthDial,
|
||||||
|
SecretKeyRing: wtmock.NewSecretKeyRing(),
|
||||||
PrivateTower: towerAddr,
|
PrivateTower: towerAddr,
|
||||||
Policy: cfg.policy,
|
Policy: cfg.policy,
|
||||||
NewAddress: func() ([]byte, error) {
|
NewAddress: func() ([]byte, error) {
|
||||||
@@ -729,6 +732,36 @@ func (h *testHarness) waitServerUpdates(hints []wtdb.BreachHint,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// assertUpdatesForPolicy queries the server db for matches using the provided
|
||||||
|
// breach hints, then asserts that each match has a session with the expected
|
||||||
|
// policy.
|
||||||
|
func (h *testHarness) assertUpdatesForPolicy(hints []wtdb.BreachHint,
|
||||||
|
expPolicy wtpolicy.Policy) {
|
||||||
|
|
||||||
|
// Query for matches on the provided hints.
|
||||||
|
matches, err := h.serverDB.QueryMatches(hints)
|
||||||
|
if err != nil {
|
||||||
|
h.t.Fatalf("unable to query for matches: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the number of matches is exactly the number of provided
|
||||||
|
// hints.
|
||||||
|
if len(matches) != len(hints) {
|
||||||
|
h.t.Fatalf("expected: %d matches, got: %d", len(hints),
|
||||||
|
len(matches))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that all of the matches correspond to a session with the
|
||||||
|
// expected policy.
|
||||||
|
for _, match := range matches {
|
||||||
|
matchPolicy := match.SessionInfo.Policy
|
||||||
|
if expPolicy != matchPolicy {
|
||||||
|
h.t.Fatalf("expected session to have policy: %v, "+
|
||||||
|
"got: %v", expPolicy, matchPolicy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
localBalance = lnwire.MilliSatoshi(100000000)
|
localBalance = lnwire.MilliSatoshi(100000000)
|
||||||
remoteBalance = lnwire.MilliSatoshi(200000000)
|
remoteBalance = lnwire.MilliSatoshi(200000000)
|
||||||
@@ -1098,6 +1131,119 @@ var clientTests = []clientTest{
|
|||||||
h.waitServerUpdates(hints, 10*time.Second)
|
h.waitServerUpdates(hints, 10*time.Second)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "create session no ack",
|
||||||
|
cfg: harnessCfg{
|
||||||
|
localBalance: localBalance,
|
||||||
|
remoteBalance: remoteBalance,
|
||||||
|
policy: wtpolicy.Policy{
|
||||||
|
BlobType: blob.TypeDefault,
|
||||||
|
MaxUpdates: 5,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
noAckCreateSession: true,
|
||||||
|
},
|
||||||
|
fn: func(h *testHarness) {
|
||||||
|
const (
|
||||||
|
chanID = 0
|
||||||
|
numUpdates = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate the retributions that will be backed up.
|
||||||
|
hints := h.advanceChannelN(chanID, numUpdates)
|
||||||
|
|
||||||
|
// Now, queue the retributions for backup.
|
||||||
|
h.backupStates(chanID, 0, numUpdates, nil)
|
||||||
|
|
||||||
|
// Since the client is unable to create a session, the
|
||||||
|
// server should have no updates.
|
||||||
|
h.waitServerUpdates(nil, time.Second)
|
||||||
|
|
||||||
|
// Force quit the client since it has queued backups.
|
||||||
|
h.client.ForceQuit()
|
||||||
|
|
||||||
|
// Restart the server and allow it to ack session
|
||||||
|
// creation.
|
||||||
|
h.server.Stop()
|
||||||
|
h.serverCfg.NoAckCreateSession = false
|
||||||
|
h.startServer()
|
||||||
|
defer h.server.Stop()
|
||||||
|
|
||||||
|
// Restart the client with the same policy, which will
|
||||||
|
// immediately try to overwrite the old session with an
|
||||||
|
// identical one.
|
||||||
|
h.startClient()
|
||||||
|
defer h.client.ForceQuit()
|
||||||
|
|
||||||
|
// Now, queue the retributions for backup.
|
||||||
|
h.backupStates(chanID, 0, numUpdates, nil)
|
||||||
|
|
||||||
|
// Wait for all of the updates to be populated in the
|
||||||
|
// server's database.
|
||||||
|
h.waitServerUpdates(hints, 5*time.Second)
|
||||||
|
|
||||||
|
// Assert that the server has updates for the clients
|
||||||
|
// most recent policy.
|
||||||
|
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create session no ack change policy",
|
||||||
|
cfg: harnessCfg{
|
||||||
|
localBalance: localBalance,
|
||||||
|
remoteBalance: remoteBalance,
|
||||||
|
policy: wtpolicy.Policy{
|
||||||
|
BlobType: blob.TypeDefault,
|
||||||
|
MaxUpdates: 5,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
noAckCreateSession: true,
|
||||||
|
},
|
||||||
|
fn: func(h *testHarness) {
|
||||||
|
const (
|
||||||
|
chanID = 0
|
||||||
|
numUpdates = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate the retributions that will be backed up.
|
||||||
|
hints := h.advanceChannelN(chanID, numUpdates)
|
||||||
|
|
||||||
|
// Now, queue the retributions for backup.
|
||||||
|
h.backupStates(chanID, 0, numUpdates, nil)
|
||||||
|
|
||||||
|
// Since the client is unable to create a session, the
|
||||||
|
// server should have no updates.
|
||||||
|
h.waitServerUpdates(nil, time.Second)
|
||||||
|
|
||||||
|
// Force quit the client since it has queued backups.
|
||||||
|
h.client.ForceQuit()
|
||||||
|
|
||||||
|
// Restart the server and allow it to ack session
|
||||||
|
// creation.
|
||||||
|
h.server.Stop()
|
||||||
|
h.serverCfg.NoAckCreateSession = false
|
||||||
|
h.startServer()
|
||||||
|
defer h.server.Stop()
|
||||||
|
|
||||||
|
// Restart the client with a new policy, which will
|
||||||
|
// immediately try to overwrite the prior session with
|
||||||
|
// the old policy.
|
||||||
|
h.clientCfg.Policy.SweepFeeRate = 2
|
||||||
|
h.startClient()
|
||||||
|
defer h.client.ForceQuit()
|
||||||
|
|
||||||
|
// Now, queue the retributions for backup.
|
||||||
|
h.backupStates(chanID, 0, numUpdates, nil)
|
||||||
|
|
||||||
|
// Wait for all of the updates to be populated in the
|
||||||
|
// server's database.
|
||||||
|
h.waitServerUpdates(hints, 5*time.Second)
|
||||||
|
|
||||||
|
// Assert that the server has updates for the clients
|
||||||
|
// most recent policy.
|
||||||
|
h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestClient executes the client test suite, asserting the ability to backup
|
// TestClient executes the client test suite, asserting the ability to backup
|
||||||
|
24
watchtower/wtclient/derivation.go
Normal file
24
watchtower/wtclient/derivation.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package wtclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeriveSessionKey accepts an session key index for an existing session and
|
||||||
|
// derives the HD private key to be used to authenticate the brontide transport
|
||||||
|
// and authenticate requests sent to the tower. The key will use the
|
||||||
|
// keychain.KeyFamilyTowerSession and the provided index, giving a BIP43
|
||||||
|
// derivation path of:
|
||||||
|
//
|
||||||
|
// * m/1017'/coinType'/8/0/index
|
||||||
|
func DeriveSessionKey(keyRing SecretKeyRing,
|
||||||
|
index uint32) (*btcec.PrivateKey, error) {
|
||||||
|
|
||||||
|
return keyRing.DerivePrivKey(keychain.KeyDescriptor{
|
||||||
|
KeyLocator: keychain.KeyLocator{
|
||||||
|
Family: keychain.KeyFamilyTowerSession,
|
||||||
|
Index: index,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/brontide"
|
"github.com/lightningnetwork/lnd/brontide"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
"github.com/lightningnetwork/lnd/watchtower/wtserver"
|
||||||
@@ -19,6 +20,17 @@ type DB interface {
|
|||||||
// sessions.
|
// sessions.
|
||||||
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
CreateTower(*lnwire.NetAddress) (*wtdb.Tower, error)
|
||||||
|
|
||||||
|
// LoadTower retrieves a tower by its tower ID.
|
||||||
|
LoadTower(uint64) (*wtdb.Tower, error)
|
||||||
|
|
||||||
|
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||||
|
// particular tower id. The index is reserved for that tower until
|
||||||
|
// CreateClientSession is invoked for that tower and index, at which
|
||||||
|
// point a new index for that tower can be reserved. Multiple calls to
|
||||||
|
// this method before CreateClientSession is invoked should return the
|
||||||
|
// same index.
|
||||||
|
NextSessionKeyIndex(uint64) (uint32, error)
|
||||||
|
|
||||||
// CreateClientSession saves a newly negotiated client session to the
|
// CreateClientSession saves a newly negotiated client session to the
|
||||||
// client's database. This enables the session to be used across
|
// client's database. This enables the session to be used across
|
||||||
// restarts.
|
// restarts.
|
||||||
@@ -74,3 +86,11 @@ func AuthDial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
|
|||||||
|
|
||||||
return brontide.Dial(localPriv, netAddr, dialer)
|
return brontide.Dial(localPriv, netAddr, dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SecretKeyRing abstracts the ability to derive HD private keys given a
|
||||||
|
// description of the derivation path.
|
||||||
|
type SecretKeyRing interface {
|
||||||
|
// DerivePrivKey derives the private key from the root seed using a
|
||||||
|
// key descriptor specifying the key's derivation path.
|
||||||
|
DerivePrivKey(loc keychain.KeyDescriptor) (*btcec.PrivateKey, error)
|
||||||
|
}
|
||||||
|
@@ -42,6 +42,10 @@ type NegotiatorConfig struct {
|
|||||||
// negotiated sessions.
|
// negotiated sessions.
|
||||||
DB DB
|
DB DB
|
||||||
|
|
||||||
|
// SecretKeyRing allows the client to derive new session private keys
|
||||||
|
// when attempting to negotiate session with a tower.
|
||||||
|
SecretKeyRing SecretKeyRing
|
||||||
|
|
||||||
// Candidates is an abstract set of tower candidates that the negotiator
|
// Candidates is an abstract set of tower candidates that the negotiator
|
||||||
// will traverse serially when attempting to negotiate a new session.
|
// will traverse serially when attempting to negotiate a new session.
|
||||||
Candidates TowerCandidateIterator
|
Candidates TowerCandidateIterator
|
||||||
@@ -224,7 +228,7 @@ func (n *sessionNegotiator) negotiate() {
|
|||||||
|
|
||||||
// On the first pass, initialize the backoff to our configured min
|
// On the first pass, initialize the backoff to our configured min
|
||||||
// backoff.
|
// backoff.
|
||||||
backoff := n.cfg.MinBackoff
|
var backoff time.Duration
|
||||||
|
|
||||||
retryWithBackoff:
|
retryWithBackoff:
|
||||||
// If we are retrying, wait out the delay before continuing.
|
// If we are retrying, wait out the delay before continuing.
|
||||||
@@ -240,14 +244,25 @@ retryWithBackoff:
|
|||||||
// iterator to ensure the results are fresh.
|
// iterator to ensure the results are fresh.
|
||||||
n.cfg.Candidates.Reset()
|
n.cfg.Candidates.Reset()
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-n.quit:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
// Pull the next candidate from our list of addresses.
|
// Pull the next candidate from our list of addresses.
|
||||||
tower, err := n.cfg.Candidates.Next()
|
tower, err := n.cfg.Candidates.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We've run out of addresses, double and clamp backoff.
|
if backoff == 0 {
|
||||||
|
backoff = n.cfg.MinBackoff
|
||||||
|
} else {
|
||||||
|
// We've run out of addresses, double and clamp
|
||||||
|
// backoff.
|
||||||
backoff *= 2
|
backoff *= 2
|
||||||
if backoff > n.cfg.MaxBackoff {
|
if backoff > n.cfg.MaxBackoff {
|
||||||
backoff = n.cfg.MaxBackoff
|
backoff = n.cfg.MaxBackoff
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("Unable to get new tower candidate, "+
|
log.Debugf("Unable to get new tower candidate, "+
|
||||||
"retrying after %v -- reason: %v", backoff, err)
|
"retrying after %v -- reason: %v", backoff, err)
|
||||||
@@ -255,12 +270,23 @@ retryWithBackoff:
|
|||||||
goto retryWithBackoff
|
goto retryWithBackoff
|
||||||
}
|
}
|
||||||
|
|
||||||
|
towerPub := tower.IdentityKey.SerializeCompressed()
|
||||||
log.Debugf("Attempting session negotiation with tower=%x",
|
log.Debugf("Attempting session negotiation with tower=%x",
|
||||||
tower.IdentityKey.SerializeCompressed())
|
towerPub)
|
||||||
|
|
||||||
|
// Before proceeding, we will reserve a session key index to use
|
||||||
|
// with this specific tower. If one is already reserved, the
|
||||||
|
// existing index will be returned.
|
||||||
|
keyIndex, err := n.cfg.DB.NextSessionKeyIndex(tower.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Unable to reserve session key index "+
|
||||||
|
"for tower=%x: %v", towerPub, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// We'll now attempt the CreateSession dance with the tower to
|
// We'll now attempt the CreateSession dance with the tower to
|
||||||
// get a new session, trying all addresses if necessary.
|
// get a new session, trying all addresses if necessary.
|
||||||
err = n.createSession(tower)
|
err = n.createSession(tower, keyIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Session negotiation with tower=%x "+
|
log.Debugf("Session negotiation with tower=%x "+
|
||||||
"failed, trying again -- reason: %v",
|
"failed, trying again -- reason: %v",
|
||||||
@@ -277,22 +303,21 @@ retryWithBackoff:
|
|||||||
// its stored addresses. This method returns after the first successful
|
// its stored addresses. This method returns after the first successful
|
||||||
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
|
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If
|
||||||
// the tower has no addresses, ErrNoTowerAddrs is returned.
|
// the tower has no addresses, ErrNoTowerAddrs is returned.
|
||||||
func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
|
func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
|
||||||
|
keyIndex uint32) error {
|
||||||
|
|
||||||
// If the tower has no addresses, there's nothing we can do.
|
// If the tower has no addresses, there's nothing we can do.
|
||||||
if len(tower.Addresses) == 0 {
|
if len(tower.Addresses) == 0 {
|
||||||
return ErrNoTowerAddrs
|
return ErrNoTowerAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(conner): create with hdkey at random index
|
sessionPriv, err := DeriveSessionKey(n.cfg.SecretKeyRing, keyIndex)
|
||||||
sessionPrivKey, err := btcec.NewPrivateKey(btcec.S256())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(conner): write towerAddr+privkey
|
|
||||||
|
|
||||||
for _, lnAddr := range tower.LNAddrs() {
|
for _, lnAddr := range tower.LNAddrs() {
|
||||||
err = n.tryAddress(sessionPrivKey, tower, lnAddr)
|
err = n.tryAddress(sessionPriv, keyIndex, tower, lnAddr)
|
||||||
switch {
|
switch {
|
||||||
case err == ErrPermanentTowerFailure:
|
case err == ErrPermanentTowerFailure:
|
||||||
// TODO(conner): report to iterator? can then be reset
|
// TODO(conner): report to iterator? can then be reset
|
||||||
@@ -318,7 +343,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower) error {
|
|||||||
// returns true if all steps succeed and the new session has been persisted, and
|
// returns true if all steps succeed and the new session has been persisted, and
|
||||||
// fails otherwise.
|
// fails otherwise.
|
||||||
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
||||||
tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
|
keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
|
||||||
|
|
||||||
// Connect to the tower address using our generated session key.
|
// Connect to the tower address using our generated session key.
|
||||||
conn, err := n.cfg.Dial(privKey, lnAddr)
|
conn, err := n.cfg.Dial(privKey, lnAddr)
|
||||||
@@ -394,7 +419,8 @@ func (n *sessionNegotiator) tryAddress(privKey *btcec.PrivateKey,
|
|||||||
clientSession := &wtdb.ClientSession{
|
clientSession := &wtdb.ClientSession{
|
||||||
TowerID: tower.ID,
|
TowerID: tower.ID,
|
||||||
Tower: tower,
|
Tower: tower,
|
||||||
SessionPrivKey: privKey, // remove after using HD keys
|
KeyIndex: keyIndex,
|
||||||
|
SessionPrivKey: privKey,
|
||||||
ID: sessionID,
|
ID: sessionID,
|
||||||
Policy: n.cfg.Policy,
|
Policy: n.cfg.Policy,
|
||||||
SeqNum: 0,
|
SeqNum: 0,
|
||||||
|
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
|
||||||
)
|
)
|
||||||
@@ -30,6 +29,15 @@ var (
|
|||||||
// LastApplied value greater than any allocated sequence number.
|
// LastApplied value greater than any allocated sequence number.
|
||||||
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
ErrUnallocatedLastApplied = errors.New("tower echoed last appiled " +
|
||||||
"greater than allocated seqnum")
|
"greater than allocated seqnum")
|
||||||
|
|
||||||
|
// ErrNoReservedKeyIndex signals that a client session could not be
|
||||||
|
// created because no session key index was reserved.
|
||||||
|
ErrNoReservedKeyIndex = errors.New("key index not reserved")
|
||||||
|
|
||||||
|
// ErrIncorrectKeyIndex signals that the client session could not be
|
||||||
|
// created because session key index differs from the reserved key
|
||||||
|
// index.
|
||||||
|
ErrIncorrectKeyIndex = errors.New("incorrect key index")
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientSession encapsulates a SessionInfo returned from a successful
|
// ClientSession encapsulates a SessionInfo returned from a successful
|
||||||
@@ -57,14 +65,17 @@ type ClientSession struct {
|
|||||||
// tower with TowerID.
|
// tower with TowerID.
|
||||||
Tower *Tower
|
Tower *Tower
|
||||||
|
|
||||||
// SessionKeyDesc is the key descriptor used to derive the client's
|
// KeyIndex is the index of key locator used to derive the client's
|
||||||
// session key so that it can authenticate with the tower to update its
|
// session key so that it can authenticate with the tower to update its
|
||||||
// session.
|
// session. In order to rederive the private key, the key locator should
|
||||||
SessionKeyDesc keychain.KeyLocator
|
// use the keychain.KeyFamilyTowerSession key family.
|
||||||
|
KeyIndex uint32
|
||||||
|
|
||||||
// SessionPrivKey is the ephemeral secret key used to connect to the
|
// SessionPrivKey is the ephemeral secret key used to connect to the
|
||||||
// watchtower.
|
// watchtower.
|
||||||
// TODO(conner): remove after HD keys
|
//
|
||||||
|
// NOTE: This value is not serialized. It is derived using the KeyIndex
|
||||||
|
// on startup to avoid storing private keys on disk.
|
||||||
SessionPrivKey *btcec.PrivateKey
|
SessionPrivKey *btcec.PrivateKey
|
||||||
|
|
||||||
// Policy holds the negotiated session parameters.
|
// Policy holds the negotiated session parameters.
|
||||||
|
@@ -61,7 +61,8 @@ func (db *MockDB) InsertSessionInfo(info *SessionInfo) error {
|
|||||||
db.mu.Lock()
|
db.mu.Lock()
|
||||||
defer db.mu.Unlock()
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
if _, ok := db.sessions[info.ID]; ok {
|
dbInfo, ok := db.sessions[info.ID]
|
||||||
|
if ok && dbInfo.LastApplied > 0 {
|
||||||
return ErrSessionAlreadyExists
|
return ErrSessionAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package wtdb
|
package wtdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -8,6 +9,12 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrTowerNotFound signals that the target tower was not found in the
|
||||||
|
// database.
|
||||||
|
ErrTowerNotFound = errors.New("tower not found")
|
||||||
|
)
|
||||||
|
|
||||||
// Tower holds the necessary components required to connect to a remote tower.
|
// Tower holds the necessary components required to connect to a remote tower.
|
||||||
// Communication is handled by brontide, and requires both a public key and an
|
// Communication is handled by brontide, and requires both a public key and an
|
||||||
// address.
|
// address.
|
||||||
|
@@ -22,6 +22,9 @@ type ClientDB struct {
|
|||||||
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
activeSessions map[wtdb.SessionID]*wtdb.ClientSession
|
||||||
towerIndex map[towerPK]uint64
|
towerIndex map[towerPK]uint64
|
||||||
towers map[uint64]*wtdb.Tower
|
towers map[uint64]*wtdb.Tower
|
||||||
|
|
||||||
|
nextIndex uint32
|
||||||
|
indexes map[uint64]uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientDB initializes a new mock ClientDB.
|
// NewClientDB initializes a new mock ClientDB.
|
||||||
@@ -31,6 +34,7 @@ func NewClientDB() *ClientDB {
|
|||||||
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
activeSessions: make(map[wtdb.SessionID]*wtdb.ClientSession),
|
||||||
towerIndex: make(map[towerPK]uint64),
|
towerIndex: make(map[towerPK]uint64),
|
||||||
towers: make(map[uint64]*wtdb.Tower),
|
towers: make(map[uint64]*wtdb.Tower),
|
||||||
|
indexes: make(map[uint64]uint32),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,6 +68,18 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
|
|||||||
return tower, nil
|
return tower, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadTower retrieves a tower by its tower ID.
|
||||||
|
func (m *ClientDB) LoadTower(towerID uint64) (*wtdb.Tower, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if tower, ok := m.towers[towerID]; ok {
|
||||||
|
return tower, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, wtdb.ErrTowerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
// MarkBackupIneligible records that particular commit height is ineligible for
|
// MarkBackupIneligible records that particular commit height is ineligible for
|
||||||
// backup. This allows the client to track which updates it should not attempt
|
// backup. This allows the client to track which updates it should not attempt
|
||||||
// to retry after startup.
|
// to retry after startup.
|
||||||
@@ -90,16 +106,29 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Ensure that a session key index has been reserved for this tower.
|
||||||
|
keyIndex, ok := m.indexes[session.TowerID]
|
||||||
|
if !ok {
|
||||||
|
return wtdb.ErrNoReservedKeyIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the session's index matches the reserved index.
|
||||||
|
if keyIndex != session.KeyIndex {
|
||||||
|
return wtdb.ErrIncorrectKeyIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the key index reservation for this tower. Once committed, this
|
||||||
|
// permits us to create another session with this tower.
|
||||||
|
delete(m.indexes, session.TowerID)
|
||||||
|
|
||||||
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
m.activeSessions[session.ID] = &wtdb.ClientSession{
|
||||||
TowerID: session.TowerID,
|
TowerID: session.TowerID,
|
||||||
Tower: session.Tower,
|
KeyIndex: session.KeyIndex,
|
||||||
SessionKeyDesc: session.SessionKeyDesc,
|
|
||||||
SessionPrivKey: session.SessionPrivKey,
|
|
||||||
ID: session.ID,
|
ID: session.ID,
|
||||||
Policy: session.Policy,
|
Policy: session.Policy,
|
||||||
SeqNum: session.SeqNum,
|
SeqNum: session.SeqNum,
|
||||||
TowerLastApplied: session.TowerLastApplied,
|
TowerLastApplied: session.TowerLastApplied,
|
||||||
RewardPkScript: session.RewardPkScript,
|
RewardPkScript: cloneBytes(session.RewardPkScript),
|
||||||
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
|
CommittedUpdates: make(map[uint16]*wtdb.CommittedUpdate),
|
||||||
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
AckedUpdates: make(map[uint16]wtdb.BackupID),
|
||||||
}
|
}
|
||||||
@@ -107,6 +136,27 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NextSessionKeyIndex reserves a new session key derivation index for a
|
||||||
|
// particular tower id. The index is reserved for that tower until
|
||||||
|
// CreateClientSession is invoked for that tower and index, at which point a new
|
||||||
|
// index for that tower can be reserved. Multiple calls to this method before
|
||||||
|
// CreateClientSession is invoked should return the same index.
|
||||||
|
func (m *ClientDB) NextSessionKeyIndex(towerID uint64) (uint32, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if index, ok := m.indexes[towerID]; ok {
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
index := m.nextIndex
|
||||||
|
m.indexes[towerID] = index
|
||||||
|
|
||||||
|
m.nextIndex++
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
|
||||||
// seqNum). This allows the client to retransmit this update on startup.
|
// seqNum). This allows the client to retransmit this update on startup.
|
||||||
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, seqNum uint16,
|
||||||
@@ -217,7 +267,12 @@ func (m *ClientDB) AddChanPkScript(chanID lnwire.ChannelID, pkScript []byte) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cloneBytes(b []byte) []byte {
|
func cloneBytes(b []byte) []byte {
|
||||||
|
if b == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
bb := make([]byte, len(b))
|
bb := make([]byte, len(b))
|
||||||
copy(bb, b)
|
copy(bb, b)
|
||||||
|
|
||||||
return bb
|
return bb
|
||||||
}
|
}
|
||||||
|
44
watchtower/wtmock/keyring.go
Normal file
44
watchtower/wtmock/keyring.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package wtmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SecretKeyRing is a mock, in-memory implementation for deriving private keys.
|
||||||
|
type SecretKeyRing struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
keys map[keychain.KeyLocator]*btcec.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSecretKeyRing creates a new mock SecretKeyRing.
|
||||||
|
func NewSecretKeyRing() *SecretKeyRing {
|
||||||
|
return &SecretKeyRing{
|
||||||
|
keys: make(map[keychain.KeyLocator]*btcec.PrivateKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DerivePrivKey derives the private key for a given key descriptor. If
|
||||||
|
// this method is called twice with the same argument, it will return the same
|
||||||
|
// private key.
|
||||||
|
func (m *SecretKeyRing) DerivePrivKey(
|
||||||
|
desc keychain.KeyDescriptor) (*btcec.PrivateKey, error) {
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if key, ok := m.keys[desc.KeyLocator]; ok {
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, err := btcec.NewPrivateKey(btcec.S256())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.keys[desc.KeyLocator] = privKey
|
||||||
|
|
||||||
|
return privKey, nil
|
||||||
|
}
|
@@ -21,45 +21,26 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
|||||||
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
|
existingInfo, err := s.cfg.DB.GetSessionInfo(id)
|
||||||
switch {
|
switch {
|
||||||
|
|
||||||
|
// We already have a session, though it is currently unused. We'll allow
|
||||||
|
// the client to recommit the session if it wanted to change the policy.
|
||||||
|
case err == nil && existingInfo.LastApplied == 0:
|
||||||
|
|
||||||
// We already have a session corresponding to this session id, return an
|
// We already have a session corresponding to this session id, return an
|
||||||
// error signaling that it already exists in our database. We return the
|
// error signaling that it already exists in our database. We return the
|
||||||
// reward address to the client in case they were not able to process
|
// reward address to the client in case they were not able to process
|
||||||
// our reply earlier.
|
// our reply earlier.
|
||||||
case err == nil:
|
case err == nil && existingInfo.LastApplied > 0:
|
||||||
log.Debugf("Already have session for %s", id)
|
log.Debugf("Already have session for %s", id)
|
||||||
return s.replyCreateSession(
|
return s.replyCreateSession(
|
||||||
peer, id, wtwire.CreateSessionCodeAlreadyExists,
|
peer, id, wtwire.CreateSessionCodeAlreadyExists,
|
||||||
existingInfo.RewardAddress,
|
existingInfo.LastApplied, existingInfo.RewardAddress,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Some other database error occurred, return a temporary failure.
|
// Some other database error occurred, return a temporary failure.
|
||||||
case err != wtdb.ErrSessionNotFound:
|
case err != wtdb.ErrSessionNotFound:
|
||||||
log.Errorf("unable to load session info for %s", id)
|
log.Errorf("unable to load session info for %s", id)
|
||||||
return s.replyCreateSession(
|
return s.replyCreateSession(
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we've established that this session does not exist in the
|
|
||||||
// database, retrieve the sweep address that will be given to the
|
|
||||||
// client. This address is to be included by the client when signing
|
|
||||||
// sweep transactions destined for this tower, if its negotiated output
|
|
||||||
// is not dust.
|
|
||||||
rewardAddress, err := s.cfg.NewAddress()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to generate reward addr for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct the pkscript the client should pay to when signing justice
|
|
||||||
// transactions for this session.
|
|
||||||
rewardScript, err := txscript.PayToAddrScript(rewardAddress)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to generate reward script for %s", id)
|
|
||||||
return s.replyCreateSession(
|
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,10 +49,39 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
|||||||
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
|
log.Debugf("Rejecting CreateSession from %s, unsupported blob "+
|
||||||
"type %s", id, req.BlobType)
|
"type %s", id, req.BlobType)
|
||||||
return s.replyCreateSession(
|
return s.replyCreateSession(
|
||||||
peer, id, wtwire.CreateSessionCodeRejectBlobType, nil,
|
peer, id, wtwire.CreateSessionCodeRejectBlobType, 0,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now that we've established that this session does not exist in the
|
||||||
|
// database, retrieve the sweep address that will be given to the
|
||||||
|
// client. This address is to be included by the client when signing
|
||||||
|
// sweep transactions destined for this tower, if its negotiated output
|
||||||
|
// is not dust.
|
||||||
|
var rewardScript []byte
|
||||||
|
if req.BlobType.Has(blob.FlagReward) {
|
||||||
|
rewardAddress, err := s.cfg.NewAddress()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to generate reward addr for %s: %v",
|
||||||
|
id, err)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the pkscript the client should pay to when signing
|
||||||
|
// justice transactions for this session.
|
||||||
|
rewardScript, err = txscript.PayToAddrScript(rewardAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to generate reward script for "+
|
||||||
|
"%s: %v", id, err)
|
||||||
|
return s.replyCreateSession(
|
||||||
|
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(conner): create invoice for upfront payment
|
// TODO(conner): create invoice for upfront payment
|
||||||
|
|
||||||
// Assemble the session info using the agreed upon parameters, reward
|
// Assemble the session info using the agreed upon parameters, reward
|
||||||
@@ -94,14 +104,14 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create session for %s", id)
|
log.Errorf("unable to create session for %s", id)
|
||||||
return s.replyCreateSession(
|
return s.replyCreateSession(
|
||||||
peer, id, wtwire.CodeTemporaryFailure, nil,
|
peer, id, wtwire.CodeTemporaryFailure, 0, nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Accepted session for %s", id)
|
log.Infof("Accepted session for %s", id)
|
||||||
|
|
||||||
return s.replyCreateSession(
|
return s.replyCreateSession(
|
||||||
peer, id, wtwire.CodeOK, rewardScript,
|
peer, id, wtwire.CodeOK, 0, rewardScript,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,10 +120,18 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
|
|||||||
// Otherwise, this method returns a connection error to ensure we don't continue
|
// Otherwise, this method returns a connection error to ensure we don't continue
|
||||||
// communication with the client.
|
// communication with the client.
|
||||||
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
||||||
code wtwire.ErrorCode, data []byte) error {
|
code wtwire.ErrorCode, lastApplied uint16, data []byte) error {
|
||||||
|
|
||||||
|
if s.cfg.NoAckCreateSession {
|
||||||
|
return &connFailure{
|
||||||
|
ID: *id,
|
||||||
|
Code: code,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
msg := &wtwire.CreateSessionReply{
|
msg := &wtwire.CreateSessionReply{
|
||||||
Code: code,
|
Code: code,
|
||||||
|
LastApplied: lastApplied,
|
||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,6 +149,6 @@ func (s *Server) replyCreateSession(peer Peer, id *wtdb.SessionID,
|
|||||||
// disconnect the client.
|
// disconnect the client.
|
||||||
return &connFailure{
|
return &connFailure{
|
||||||
ID: *id,
|
ID: *id,
|
||||||
Code: uint16(code),
|
Code: code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -52,6 +52,6 @@ func (s *Server) replyDeleteSession(peer Peer, id *wtdb.SessionID,
|
|||||||
// disconnect the client.
|
// disconnect the client.
|
||||||
return &connFailure{
|
return &connFailure{
|
||||||
ID: *id,
|
ID: *id,
|
||||||
Code: uint16(code),
|
Code: code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -56,6 +56,10 @@ type Config struct {
|
|||||||
// ChainHash identifies the network that the server is watching.
|
// ChainHash identifies the network that the server is watching.
|
||||||
ChainHash chainhash.Hash
|
ChainHash chainhash.Hash
|
||||||
|
|
||||||
|
// NoAckCreateSession causes the server to not reply to create session
|
||||||
|
// requests, this should only be used for testing.
|
||||||
|
NoAckCreateSession bool
|
||||||
|
|
||||||
// NoAckUpdates causes the server to not acknowledge state updates, this
|
// NoAckUpdates causes the server to not acknowledge state updates, this
|
||||||
// should only be used for testing.
|
// should only be used for testing.
|
||||||
NoAckUpdates bool
|
NoAckUpdates bool
|
||||||
@@ -283,12 +287,12 @@ func (s *Server) handleClient(peer Peer) {
|
|||||||
// error code.
|
// error code.
|
||||||
type connFailure struct {
|
type connFailure struct {
|
||||||
ID wtdb.SessionID
|
ID wtdb.SessionID
|
||||||
Code uint16
|
Code wtwire.ErrorCode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error displays the SessionID and Code that caused the connection failure.
|
// Error displays the SessionID and Code that caused the connection failure.
|
||||||
func (f *connFailure) Error() string {
|
func (f *connFailure) Error() string {
|
||||||
return fmt.Sprintf("connection with %s failed with code=%v",
|
return fmt.Sprintf("connection with %s failed with code=%s",
|
||||||
f.ID, f.Code,
|
f.ID, f.Code,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@@ -29,6 +29,8 @@ var (
|
|||||||
addrScript, _ = txscript.PayToAddrScript(addr)
|
addrScript, _ = txscript.PayToAddrScript(addr)
|
||||||
|
|
||||||
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
|
testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
|
||||||
|
|
||||||
|
rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type()
|
||||||
)
|
)
|
||||||
|
|
||||||
// randPubKey generates a new secp keypair, and returns the public key.
|
// randPubKey generates a new secp keypair, and returns the public key.
|
||||||
@@ -157,11 +159,12 @@ type createSessionTestCase struct {
|
|||||||
createMsg *wtwire.CreateSession
|
createMsg *wtwire.CreateSession
|
||||||
expReply *wtwire.CreateSessionReply
|
expReply *wtwire.CreateSessionReply
|
||||||
expDupReply *wtwire.CreateSessionReply
|
expDupReply *wtwire.CreateSessionReply
|
||||||
|
sendStateUpdate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var createSessionTests = []createSessionTestCase{
|
var createSessionTests = []createSessionTestCase{
|
||||||
{
|
{
|
||||||
name: "reject duplicate session create",
|
name: "duplicate session create",
|
||||||
initMsg: wtwire.NewInitMessage(
|
initMsg: wtwire.NewInitMessage(
|
||||||
lnwire.NewRawFeatureVector(),
|
lnwire.NewRawFeatureVector(),
|
||||||
testnetChainHash,
|
testnetChainHash,
|
||||||
@@ -175,10 +178,56 @@ var createSessionTests = []createSessionTestCase{
|
|||||||
},
|
},
|
||||||
expReply: &wtwire.CreateSessionReply{
|
expReply: &wtwire.CreateSessionReply{
|
||||||
Code: wtwire.CodeOK,
|
Code: wtwire.CodeOK,
|
||||||
Data: addrScript,
|
Data: []byte{},
|
||||||
|
},
|
||||||
|
expDupReply: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
|
Data: []byte{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicate session create after use",
|
||||||
|
initMsg: wtwire.NewInitMessage(
|
||||||
|
lnwire.NewRawFeatureVector(),
|
||||||
|
testnetChainHash,
|
||||||
|
),
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobType: blob.TypeDefault,
|
||||||
|
MaxUpdates: 1000,
|
||||||
|
RewardBase: 0,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
expReply: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
|
Data: []byte{},
|
||||||
},
|
},
|
||||||
expDupReply: &wtwire.CreateSessionReply{
|
expDupReply: &wtwire.CreateSessionReply{
|
||||||
Code: wtwire.CreateSessionCodeAlreadyExists,
|
Code: wtwire.CreateSessionCodeAlreadyExists,
|
||||||
|
LastApplied: 1,
|
||||||
|
Data: []byte{},
|
||||||
|
},
|
||||||
|
sendStateUpdate: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicate session create reward",
|
||||||
|
initMsg: wtwire.NewInitMessage(
|
||||||
|
lnwire.NewRawFeatureVector(),
|
||||||
|
testnetChainHash,
|
||||||
|
),
|
||||||
|
createMsg: &wtwire.CreateSession{
|
||||||
|
BlobType: rewardType,
|
||||||
|
MaxUpdates: 1000,
|
||||||
|
RewardBase: 0,
|
||||||
|
RewardRate: 0,
|
||||||
|
SweepFeeRate: 1,
|
||||||
|
},
|
||||||
|
expReply: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
|
Data: addrScript,
|
||||||
|
},
|
||||||
|
expDupReply: &wtwire.CreateSessionReply{
|
||||||
|
Code: wtwire.CodeOK,
|
||||||
Data: addrScript,
|
Data: addrScript,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -251,6 +300,18 @@ func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if test.sendStateUpdate {
|
||||||
|
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
|
connect(t, s, peer, test.initMsg, timeoutDuration)
|
||||||
|
update := &wtwire.StateUpdate{
|
||||||
|
SeqNum: 1,
|
||||||
|
IsComplete: 1,
|
||||||
|
}
|
||||||
|
sendMsg(t, update, peer, timeoutDuration)
|
||||||
|
|
||||||
|
assertConnClosed(t, peer, 2*timeoutDuration)
|
||||||
|
}
|
||||||
|
|
||||||
// Simulate a peer with the same session id connection to the server
|
// Simulate a peer with the same session id connection to the server
|
||||||
// again.
|
// again.
|
||||||
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
|
||||||
@@ -705,7 +766,7 @@ func TestServerDeleteSession(t *testing.T) {
|
|||||||
send: createSession,
|
send: createSession,
|
||||||
recv: &wtwire.CreateSessionReply{
|
recv: &wtwire.CreateSessionReply{
|
||||||
Code: wtwire.CodeOK,
|
Code: wtwire.CodeOK,
|
||||||
Data: addrScript,
|
Data: []byte{},
|
||||||
},
|
},
|
||||||
assert: func(t *testing.T) {
|
assert: func(t *testing.T) {
|
||||||
// Both peers should have sessions.
|
// Both peers should have sessions.
|
||||||
|
@@ -117,7 +117,7 @@ func (s *Server) handleStateUpdate(peer Peer, id *wtdb.SessionID,
|
|||||||
if s.cfg.NoAckUpdates {
|
if s.cfg.NoAckUpdates {
|
||||||
return &connFailure{
|
return &connFailure{
|
||||||
ID: *id,
|
ID: *id,
|
||||||
Code: uint16(failCode),
|
Code: failCode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +152,6 @@ func (s *Server) replyStateUpdate(peer Peer, id *wtdb.SessionID,
|
|||||||
// disconnect the client.
|
// disconnect the client.
|
||||||
return &connFailure{
|
return &connFailure{
|
||||||
ID: *id,
|
ID: *id,
|
||||||
Code: uint16(code),
|
Code: code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -43,6 +43,12 @@ type CreateSessionReply struct {
|
|||||||
// Code will be non-zero if the watchtower rejected the session init.
|
// Code will be non-zero if the watchtower rejected the session init.
|
||||||
Code CreateSessionCode
|
Code CreateSessionCode
|
||||||
|
|
||||||
|
// LastApplied is the tower's last accepted sequence number for the
|
||||||
|
// session. This is useful when the session already exists but the
|
||||||
|
// client doesn't realize it's already used the session, such as after a
|
||||||
|
// restoration.
|
||||||
|
LastApplied uint16
|
||||||
|
|
||||||
// Data is a byte slice returned the caller of the message, and is to be
|
// Data is a byte slice returned the caller of the message, and is to be
|
||||||
// interpreted according to the error Code. When the response is
|
// interpreted according to the error Code. When the response is
|
||||||
// CreateSessionCodeOK, data encodes the reward address to be included in
|
// CreateSessionCodeOK, data encodes the reward address to be included in
|
||||||
@@ -63,6 +69,7 @@ var _ Message = (*CreateSessionReply)(nil)
|
|||||||
func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
|
func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
|
||||||
return ReadElements(r,
|
return ReadElements(r,
|
||||||
&m.Code,
|
&m.Code,
|
||||||
|
&m.LastApplied,
|
||||||
&m.Data,
|
&m.Data,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -74,6 +81,7 @@ func (m *CreateSessionReply) Decode(r io.Reader, pver uint32) error {
|
|||||||
func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error {
|
func (m *CreateSessionReply) Encode(w io.Writer, pver uint32) error {
|
||||||
return WriteElements(w,
|
return WriteElements(w,
|
||||||
m.Code,
|
m.Code,
|
||||||
|
m.LastApplied,
|
||||||
m.Data,
|
m.Data,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@@ -12,8 +12,6 @@ const (
|
|||||||
// client side, or that the tower had already deleted the session in a
|
// client side, or that the tower had already deleted the session in a
|
||||||
// prior request that the client may not have received.
|
// prior request that the client may not have received.
|
||||||
DeleteSessionCodeNotFound DeleteSessionCode = 80
|
DeleteSessionCodeNotFound DeleteSessionCode = 80
|
||||||
|
|
||||||
// TODO(conner): add String method after wtclient is merged
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DeleteSessionReply is a message sent in response to a client's DeleteSession
|
// DeleteSessionReply is a message sent in response to a client's DeleteSession
|
||||||
|
@@ -46,6 +46,8 @@ func (c ErrorCode) String() string {
|
|||||||
return "StateUpdateCodeMaxUpdatesExceeded"
|
return "StateUpdateCodeMaxUpdatesExceeded"
|
||||||
case StateUpdateCodeSeqNumOutOfOrder:
|
case StateUpdateCodeSeqNumOutOfOrder:
|
||||||
return "StateUpdateCodeSeqNumOutOfOrder"
|
return "StateUpdateCodeSeqNumOutOfOrder"
|
||||||
|
case DeleteSessionCodeNotFound:
|
||||||
|
return "DeleteSessionCodeNotFound"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("UnknownErrorCode: %d", c)
|
return fmt.Sprintf("UnknownErrorCode: %d", c)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user