From e150bb83d1d94928fc2ad8b9fd23ee8a5c4f337e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 4 Oct 2022 15:08:18 +0200 Subject: [PATCH] watchtower/wtdb: check tower exists on session create Before creating a new session, first check that the TowerID that the ClientSession is referencing refers to an existing tower. This is done to prevent the situation where RemoveTower is called right before CreateClientSession is called which would, before this commit, lead to the session being created with a tower ID that does not refer to any existing tower. --- watchtower/wtdb/client_db.go | 11 ++++++ watchtower/wtdb/client_db_test.go | 63 ++++++++++++++++++------------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index c868b949f..7f4d7d6f5 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -574,6 +574,11 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrUninitializedDB } + towers := tx.ReadBucket(cTowerBkt) + if towers == nil { + return ErrUninitializedDB + } + // Check that client session with this session id doesn't // already exist. existingSessionBytes := sessions.NestedReadWriteBucket( @@ -583,7 +588,13 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrClientSessionAlreadyExists } + // Ensure that a tower with the given ID actually exists in the + // DB. towerID := session.TowerID + if _, err := getTower(towers, towerID.Bytes()); err != nil { + return err + } + blobType := session.Policy.BlobType // Check that this tower has a reserved key index. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index b0e4af6a3..d4f1699c9 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -17,6 +17,9 @@ import ( "github.com/stretchr/testify/require" ) +// pseudoAddr is a fake network address to be used for testing purposes. +var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} + // clientDBInit is a closure used to initialize a wtclient.DB instance. type clientDBInit func(t *testing.T) wtclient.DB @@ -189,6 +192,21 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, require.ErrorIs(h.t, err, expErr) } +// newTower is a helper function that creates a new tower with a randomly +// generated public key and inserts it into the client DB. +func (h *clientDBHarness) newTower() *wtdb.Tower { + h.t.Helper() + + pk, err := randPubKey() + require.NoError(h.t, err) + + // Insert a random tower into the database. + return h.createTower(&lnwire.NetAddress{ + IdentityKey: pk, + Address: pseudoAddr, + }, nil) +} + // testCreateClientSession asserts various conditions regarding the creation of // a new ClientSession. The test asserts: // - client sessions can only be created if a session key index is reserved. @@ -197,10 +215,12 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, func testCreateClientSession(h *clientDBHarness) { const blobType = blob.TypeAltruistAnchorCommit + tower := h.newTower() + // Create a test client session to insert. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -265,15 +285,12 @@ func testFilterClientSessions(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID) for i := 0; i < numSessions; i++ { - towerID := wtdb.TowerID(1) - if i == numSessions-1 { - towerID = wtdb.TowerID(2) - } - keyIndex := h.nextKeyIndex(towerID, blobType) + tower := h.newTower() + keyIndex := h.nextKeyIndex(tower.ID, blobType) sessionID := wtdb.SessionID([33]byte{byte(i)}) h.insertSession(&wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: towerID, + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -285,8 +302,8 @@ func testFilterClientSessions(h *clientDBHarness) { }, ID: sessionID, }, nil) - towerSessions[towerID] = append( - towerSessions[towerID], sessionID, + towerSessions[tower.ID] = append( + towerSessions[tower.ID], sessionID, ) } @@ -311,19 +328,9 @@ func testCreateTower(h *clientDBHarness) { // Test that loading a tower with an arbitrary tower id fails. h.loadTowerByID(20, wtdb.ErrTowerNotFound) - pk, err := randPubKey() - if err != nil { - h.t.Fatalf("unable to generate pubkey: %v", err) - } - - addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} - lnAddr := &lnwire.NetAddress{ - IdentityKey: pk, - Address: addr1, - } - - // Insert a random tower into the database. - tower := h.createTower(lnAddr, nil) + tower := h.newTower() + require.Len(h.t, tower.LNAddrs(), 1) + towerAddr := tower.LNAddrs()[0] // Load the tower from the database and assert that it matches the tower // we created. @@ -335,7 +342,7 @@ func testCreateTower(h *clientDBHarness) { // Insert the address again into the database. Since the address is the // same, this should result in an unmodified tower record. - towerDupAddr := h.createTower(lnAddr, nil) + towerDupAddr := h.createTower(towerAddr, nil) require.Lenf(h.t, towerDupAddr.Addresses, 1, "duplicate address "+ "should be deduped") @@ -345,7 +352,7 @@ func testCreateTower(h *clientDBHarness) { addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} lnAddr2 := &lnwire.NetAddress{ - IdentityKey: pk, + IdentityKey: tower.IdentityKey, Address: addr2, } @@ -479,9 +486,11 @@ func testChanSummaries(h *clientDBHarness) { // testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can func testCommitUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit + + tower := h.newTower() session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType, @@ -570,10 +579,12 @@ func testCommitUpdate(h *clientDBHarness) { func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit + tower := h.newTower() + // Create a new session that the updates in this will be tied to. session := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ - TowerID: wtdb.TowerID(3), + TowerID: tower.ID, Policy: wtpolicy.Policy{ TxPolicy: wtpolicy.TxPolicy{ BlobType: blobType,