diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 9751988ba..32ebe93ad 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -64,6 +64,11 @@ type DB interface { ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) + // GetClientSession loads the ClientSession with the given ID from the + // DB. + GetClientSession(wtdb.SessionID, + ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) + // FetchSessionCommittedUpdates retrieves the current set of un-acked // updates of the given session. FetchSessionCommittedUpdates(id *wtdb.SessionID) ( diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 4801ca1ab..eaf188470 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1009,6 +1009,36 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, return byteOrder.Uint32(keyIndexBytes), nil } +// GetClientSession loads the ClientSession with the given ID from the DB. +func (c *ClientDB) GetClientSession(id SessionID, + opts ...ClientSessionListOption) (*ClientSession, error) { + + var sess *ClientSession + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + session, err := c.getClientSession( + sessionsBkt, chanIDIndexBkt, id[:], nil, opts..., + ) + if err != nil { + return err + } + + sess = session + return nil + }, func() {}) + + return sess, err +} + // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 071ee7782..9d38c2da2 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -554,6 +554,36 @@ func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { return summaries, nil } +// GetClientSession loads the ClientSession with the given ID from the DB. +func (m *ClientDB) GetClientSession(id wtdb.SessionID, + opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { + + cfg := wtdb.NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } + + session, ok := m.activeSessions[id] + if !ok { + return nil, wtdb.ErrClientSessionNotFound + } + + if cfg.PerMaxHeight != nil { + for chanID, index := range m.ackedUpdates[session.ID] { + cfg.PerMaxHeight(&session, chanID, index.MaxHeight()) + } + } + + if cfg.PerCommittedUpdate != nil { + for _, update := range m.committedUpdates[session.ID] { + update := update + cfg.PerCommittedUpdate(&session, &update) + } + } + + return &session, nil +} + // RegisterChannel registers a channel for use within the client database. For // now, all that is stored in the channel summary is the sweep pkscript that // we'd like any tower sweeps to pay into. In the future, this will be extended