watchtower/wtdb: check tower exists on session create

Before creating a new session, first check that the TowerID that the
ClientSession is referencing refers to an existing tower. This is done
to prevent the situation where RemoveTower is called right before
CreateClientSession is called which would, before this commit, lead to
the session being created with a tower ID that does not refer to any
existing tower.
This commit is contained in:
Elle Mouton 2022-10-04 15:08:18 +02:00
parent 5dabf7cb3e
commit e150bb83d1
No known key found for this signature in database
GPG Key ID: D7D916376026F177
2 changed files with 48 additions and 26 deletions

View File

@ -574,6 +574,11 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
return ErrUninitializedDB
}
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
// Check that client session with this session id doesn't
// already exist.
existingSessionBytes := sessions.NestedReadWriteBucket(
@ -583,7 +588,13 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
return ErrClientSessionAlreadyExists
}
// Ensure that a tower with the given ID actually exists in the
// DB.
towerID := session.TowerID
if _, err := getTower(towers, towerID.Bytes()); err != nil {
return err
}
blobType := session.Policy.BlobType
// Check that this tower has a reserved key index.

View File

@ -17,6 +17,9 @@ import (
"github.com/stretchr/testify/require"
)
// pseudoAddr is a fake network address to be used for testing purposes.
var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
// clientDBInit is a closure used to initialize a wtclient.DB instance.
type clientDBInit func(t *testing.T) wtclient.DB
@ -189,6 +192,21 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
require.ErrorIs(h.t, err, expErr)
}
// newTower is a helper function that creates a new tower with a randomly
// generated public key and inserts it into the client DB.
func (h *clientDBHarness) newTower() *wtdb.Tower {
h.t.Helper()
pk, err := randPubKey()
require.NoError(h.t, err)
// Insert a random tower into the database.
return h.createTower(&lnwire.NetAddress{
IdentityKey: pk,
Address: pseudoAddr,
}, nil)
}
// testCreateClientSession asserts various conditions regarding the creation of
// a new ClientSession. The test asserts:
// - client sessions can only be created if a session key index is reserved.
@ -197,10 +215,12 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
func testCreateClientSession(h *clientDBHarness) {
const blobType = blob.TypeAltruistAnchorCommit
tower := h.newTower()
// Create a test client session to insert.
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -265,15 +285,12 @@ func testFilterClientSessions(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
for i := 0; i < numSessions; i++ {
towerID := wtdb.TowerID(1)
if i == numSessions-1 {
towerID = wtdb.TowerID(2)
}
keyIndex := h.nextKeyIndex(towerID, blobType)
tower := h.newTower()
keyIndex := h.nextKeyIndex(tower.ID, blobType)
sessionID := wtdb.SessionID([33]byte{byte(i)})
h.insertSession(&wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: towerID,
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -285,8 +302,8 @@ func testFilterClientSessions(h *clientDBHarness) {
},
ID: sessionID,
}, nil)
towerSessions[towerID] = append(
towerSessions[towerID], sessionID,
towerSessions[tower.ID] = append(
towerSessions[tower.ID], sessionID,
)
}
@ -311,19 +328,9 @@ func testCreateTower(h *clientDBHarness) {
// Test that loading a tower with an arbitrary tower id fails.
h.loadTowerByID(20, wtdb.ErrTowerNotFound)
pk, err := randPubKey()
if err != nil {
h.t.Fatalf("unable to generate pubkey: %v", err)
}
addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
lnAddr := &lnwire.NetAddress{
IdentityKey: pk,
Address: addr1,
}
// Insert a random tower into the database.
tower := h.createTower(lnAddr, nil)
tower := h.newTower()
require.Len(h.t, tower.LNAddrs(), 1)
towerAddr := tower.LNAddrs()[0]
// Load the tower from the database and assert that it matches the tower
// we created.
@ -335,7 +342,7 @@ func testCreateTower(h *clientDBHarness) {
// Insert the address again into the database. Since the address is the
// same, this should result in an unmodified tower record.
towerDupAddr := h.createTower(lnAddr, nil)
towerDupAddr := h.createTower(towerAddr, nil)
require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+
"should be deduped")
@ -345,7 +352,7 @@ func testCreateTower(h *clientDBHarness) {
addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
lnAddr2 := &lnwire.NetAddress{
IdentityKey: pk,
IdentityKey: tower.IdentityKey,
Address: addr2,
}
@ -479,9 +486,11 @@ func testChanSummaries(h *clientDBHarness) {
// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
func testCommitUpdate(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
tower := h.newTower()
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,
@ -570,10 +579,12 @@ func testCommitUpdate(h *clientDBHarness) {
func testAckUpdate(h *clientDBHarness) {
const blobType = blob.TypeAltruistCommit
tower := h.newTower()
// Create a new session that the updates in this will be tied to.
session := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{
TowerID: wtdb.TowerID(3),
TowerID: tower.ID,
Policy: wtpolicy.Policy{
TxPolicy: wtpolicy.TxPolicy{
BlobType: blobType,