mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-25 16:23:49 +02:00
watchtower: always populate Tower in ClientSession
In this commit, we make sure to always populate the Tower member of a ClientSession. This is done for consistency.
This commit is contained in:
parent
e150bb83d1
commit
c60ecaccbf
@ -354,8 +354,8 @@ func New(config *Config) (*TowerClient, error) {
|
|||||||
// optional filter can be provided to filter out any undesired client sessions.
|
// optional filter can be provided to filter out any undesired client sessions.
|
||||||
//
|
//
|
||||||
// NOTE: This method should only be used when deserialization of a
|
// NOTE: This method should only be used when deserialization of a
|
||||||
// ClientSession's Tower and SessionPrivKey fields is desired, otherwise, the
|
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
|
||||||
// existing ListClientSessions method should be used.
|
// ListClientSessions method should be used.
|
||||||
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
||||||
passesFilter func(*wtdb.ClientSession) bool) (
|
passesFilter func(*wtdb.ClientSession) bool) (
|
||||||
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
map[wtdb.SessionID]*wtdb.ClientSession, error) {
|
||||||
@ -371,12 +371,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
|
|||||||
// requests. This prevents us from having to store the private keys on
|
// requests. This prevents us from having to store the private keys on
|
||||||
// disk.
|
// disk.
|
||||||
for _, s := range sessions {
|
for _, s := range sessions {
|
||||||
tower, err := db.LoadTowerByID(s.TowerID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s.Tower = tower
|
|
||||||
|
|
||||||
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
|
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
|
||||||
Family: keychain.KeyFamilyTowerSession,
|
Family: keychain.KeyFamilyTowerSession,
|
||||||
Index: s.KeyIndex,
|
Index: s.KeyIndex,
|
||||||
|
@ -288,7 +288,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
|
|||||||
}
|
}
|
||||||
towerID := TowerIDFromBytes(towerIDBytes)
|
towerID := TowerIDFromBytes(towerIDBytes)
|
||||||
towerSessions, err := listClientSessions(
|
towerSessions, err := listClientSessions(
|
||||||
sessions, &towerID,
|
sessions, towers, &towerID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -389,7 +389,9 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
|
|||||||
return ErrUninitializedDB
|
return ErrUninitializedDB
|
||||||
}
|
}
|
||||||
towerID := TowerIDFromBytes(towerIDBytes)
|
towerID := TowerIDFromBytes(towerIDBytes)
|
||||||
towerSessions, err := listClientSessions(sessions, &towerID)
|
towerSessions, err := listClientSessions(
|
||||||
|
sessions, towers, &towerID,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -685,8 +687,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
|||||||
if sessions == nil {
|
if sessions == nil {
|
||||||
return ErrUninitializedDB
|
return ErrUninitializedDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
towers := tx.ReadBucket(cTowerBkt)
|
||||||
|
if towers == nil {
|
||||||
|
return ErrUninitializedDB
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
clientSessions, err = listClientSessions(sessions, id)
|
clientSessions, err = listClientSessions(sessions, towers, id)
|
||||||
return err
|
return err
|
||||||
}, func() {
|
}, func() {
|
||||||
clientSessions = nil
|
clientSessions = nil
|
||||||
@ -701,7 +709,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
|
|||||||
// listClientSessions returns the set of all client sessions known to the db. An
|
// listClientSessions returns the set of all client sessions known to the db. An
|
||||||
// optional tower ID can be used to filter out any client sessions in the
|
// optional tower ID can be used to filter out any client sessions in the
|
||||||
// response that do not correspond to this tower.
|
// response that do not correspond to this tower.
|
||||||
func listClientSessions(sessions kvdb.RBucket,
|
func listClientSessions(sessions, towers kvdb.RBucket,
|
||||||
id *TowerID) (map[SessionID]*ClientSession, error) {
|
id *TowerID) (map[SessionID]*ClientSession, error) {
|
||||||
|
|
||||||
clientSessions := make(map[SessionID]*ClientSession)
|
clientSessions := make(map[SessionID]*ClientSession)
|
||||||
@ -710,7 +718,7 @@ func listClientSessions(sessions kvdb.RBucket,
|
|||||||
// the CommittedUpdates and AckedUpdates on startup to resume
|
// the CommittedUpdates and AckedUpdates on startup to resume
|
||||||
// committed updates and compute the highest known commit height
|
// committed updates and compute the highest known commit height
|
||||||
// for each channel.
|
// for each channel.
|
||||||
session, err := getClientSession(sessions, k)
|
session, err := getClientSession(sessions, towers, k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1022,8 +1030,8 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
|||||||
|
|
||||||
// getClientSessionBody loads the body of a ClientSession from the sessions
|
// getClientSessionBody loads the body of a ClientSession from the sessions
|
||||||
// bucket corresponding to the serialized session id. This does not deserialize
|
// bucket corresponding to the serialized session id. This does not deserialize
|
||||||
// the CommittedUpdates or AckUpdates associated with the session. If the caller
|
// the CommittedUpdates, AckUpdates or the Tower associated with the session.
|
||||||
// requires this info, use getClientSession.
|
// If the caller requires this info, use getClientSession.
|
||||||
func getClientSessionBody(sessions kvdb.RBucket,
|
func getClientSessionBody(sessions kvdb.RBucket,
|
||||||
idBytes []byte) (*ClientSession, error) {
|
idBytes []byte) (*ClientSession, error) {
|
||||||
|
|
||||||
@ -1050,9 +1058,9 @@ func getClientSessionBody(sessions kvdb.RBucket,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getClientSession loads the full ClientSession associated with the serialized
|
// getClientSession loads the full ClientSession associated with the serialized
|
||||||
// session id. This method populates the CommittedUpdates and AckUpdates in
|
// session id. This method populates the CommittedUpdates, AckUpdates and Tower
|
||||||
// addition to the ClientSession's body.
|
// in addition to the ClientSession's body.
|
||||||
func getClientSession(sessions kvdb.RBucket,
|
func getClientSession(sessions, towers kvdb.RBucket,
|
||||||
idBytes []byte) (*ClientSession, error) {
|
idBytes []byte) (*ClientSession, error) {
|
||||||
|
|
||||||
session, err := getClientSessionBody(sessions, idBytes)
|
session, err := getClientSessionBody(sessions, idBytes)
|
||||||
@ -1060,6 +1068,12 @@ func getClientSession(sessions kvdb.RBucket,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch the tower associated with this session.
|
||||||
|
tower, err := getTower(towers, session.TowerID.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Fetch the committed updates for this session.
|
// Fetch the committed updates for this session.
|
||||||
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
|
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1072,6 +1086,7 @@ func getClientSession(sessions kvdb.RBucket,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session.Tower = tower
|
||||||
session.CommittedUpdates = commitedUpdates
|
session.CommittedUpdates = commitedUpdates
|
||||||
session.AckedUpdates = ackedUpdates
|
session.AckedUpdates = ackedUpdates
|
||||||
|
|
||||||
|
@ -220,6 +220,7 @@ func (m *ClientDB) listClientSessions(
|
|||||||
if tower != nil && *tower != session.TowerID {
|
if tower != nil && *tower != session.TowerID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
session.Tower = m.towers[session.TowerID]
|
||||||
sessions[session.ID] = &session
|
sessions[session.ID] = &session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user