watchtower: always populate Tower in ClientSession

In this commit, we make sure to always populate the Tower member of a
ClientSession. This is done for consistency.
This commit is contained in:
Elle Mouton
2022-10-04 15:18:40 +02:00
parent e150bb83d1
commit c60ecaccbf
3 changed files with 28 additions and 18 deletions

View File

@@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
}
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(
sessions, &towerID,
sessions, towers, &towerID,
)
if err != nil {
return err
@@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB
}
towerID := TowerIDFromBytes(towerIDBytes)
towerSessions, err := listClientSessions(sessions, &towerID)
towerSessions, err := listClientSessions(
sessions, towers, &towerID,
)
if err != nil {
return err
}
@@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
if sessions == nil {
return ErrUninitializedDB
}
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
var err error
clientSessions, err = listClientSessions(sessions, id)
clientSessions, err = listClientSessions(sessions, towers, id)
return err
}, func() {
clientSessions = nil
@@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func listClientSessions(sessions kvdb.RBucket,
func listClientSessions(sessions, towers kvdb.RBucket,
id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
@@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket,
// the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height
// for each channel.
session, err := getClientSession(sessions, k)
session, err := getClientSession(sessions, towers, k)
if err != nil {
return err
}
@@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// getClientSessionBody loads the body of a ClientSession from the sessions
// bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckUpdates associated with the session. If the caller
// requires this info, use getClientSession.
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
// If the caller requires this info, use getClientSession.
func getClientSessionBody(sessions kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
@@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
}
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckUpdates in
// addition to the ClientSession's body.
func getClientSession(sessions kvdb.RBucket,
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket,
idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes)
@@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err
}
// Fetch the tower associated with this session.
tower, err := getTower(towers, session.TowerID.Bytes())
if err != nil {
return nil, err
}
// Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
if err != nil {
@@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket,
return nil, err
}
session.Tower = tower
session.CommittedUpdates = commitedUpdates
session.AckedUpdates = ackedUpdates