watchtower: convert to use new kvdb abstraction

This commit is contained in:
Olaoluwa Osuntokun
2020-01-09 18:45:04 -08:00
parent 28bbaa2a94
commit 557b930c5f
4 changed files with 115 additions and 120 deletions

View File

@@ -8,7 +8,7 @@ import (
"net"
"github.com/btcsuite/btcd/btcec"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
@@ -113,7 +113,7 @@ var (
// ClientDB is single database providing a persistent storage engine for the
// wtclient.
type ClientDB struct {
db *bbolt.DB
db kvdb.Backend
dbPath string
}
@@ -146,7 +146,7 @@ func OpenClientDB(dbPath string) (*ClientDB, error) {
// initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error.
err = clientDB.db.Update(initClientDBBuckets)
err = kvdb.Update(clientDB.db, initClientDBBuckets)
if err != nil {
bdb.Close()
return nil, err
@@ -157,7 +157,7 @@ func OpenClientDB(dbPath string) (*ClientDB, error) {
// initClientDBBuckets creates all top-level buckets required to handle database
// operations required by the latest version.
func initClientDBBuckets(tx *bbolt.Tx) error {
func initClientDBBuckets(tx kvdb.RwTx) error {
buckets := [][]byte{
cSessionKeyIndexBkt,
cChanSummaryBkt,
@@ -167,7 +167,7 @@ func initClientDBBuckets(tx *bbolt.Tx) error {
}
for _, bucket := range buckets {
_, err := tx.CreateBucketIfNotExists(bucket)
_, err := tx.CreateTopLevelBucket(bucket)
if err != nil {
return err
}
@@ -179,7 +179,7 @@ func initClientDBBuckets(tx *bbolt.Tx) error {
// bdb returns the backing bbolt.DB instance.
//
// NOTE: Part of the versionedDB interface.
func (c *ClientDB) bdb() *bbolt.DB {
func (c *ClientDB) bdb() kvdb.Backend {
return c.db
}
@@ -188,7 +188,7 @@ func (c *ClientDB) bdb() *bbolt.DB {
// NOTE: Part of the versionedDB interface.
func (c *ClientDB) Version() (uint32, error) {
var version uint32
err := c.db.View(func(tx *bbolt.Tx) error {
err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
var err error
version, err = getDBVersion(tx)
return err
@@ -215,13 +215,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
var tower *Tower
err := c.db.Update(func(tx *bbolt.Tx) error {
towerIndex := tx.Bucket(cTowerIndexBkt)
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
towerIndex := tx.ReadWriteBucket(cTowerIndexBkt)
if towerIndex == nil {
return ErrUninitializedDB
}
towers := tx.Bucket(cTowerBkt)
towers := tx.ReadWriteBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
@@ -248,7 +248,7 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
//
// TODO(wilmer): with an index of tower -> sessions we
// can avoid the linear lookup.
sessions := tx.Bucket(cSessionBkt)
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
@@ -308,12 +308,12 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
//
// NOTE: An error is not returned if the tower doesn't exist.
func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return c.db.Update(func(tx *bbolt.Tx) error {
towers := tx.Bucket(cTowerBkt)
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
towers := tx.ReadWriteBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
towerIndex := tx.Bucket(cTowerIndexBkt)
towerIndex := tx.ReadWriteBucket(cTowerIndexBkt)
if towerIndex == nil {
return ErrUninitializedDB
}
@@ -342,7 +342,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
//
// TODO(wilmer): with an index of tower -> sessions we can avoid
// the linear lookup.
sessions := tx.Bucket(cSessionBkt)
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
@@ -383,8 +383,8 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// LoadTowerByID retrieves a tower by its tower ID.
func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
var tower *Tower
err := c.db.View(func(tx *bbolt.Tx) error {
towers := tx.Bucket(cTowerBkt)
err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
@@ -403,12 +403,12 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
// LoadTower retrieves a tower by its public key.
func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
var tower *Tower
err := c.db.View(func(tx *bbolt.Tx) error {
towers := tx.Bucket(cTowerBkt)
err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
towers := tx.ReadBucket(cTowerBkt)
if towers == nil {
return ErrUninitializedDB
}
towerIndex := tx.Bucket(cTowerIndexBkt)
towerIndex := tx.ReadBucket(cTowerIndexBkt)
if towerIndex == nil {
return ErrUninitializedDB
}
@@ -432,8 +432,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
// ListTowers retrieves the list of towers available within the database.
func (c *ClientDB) ListTowers() ([]*Tower, error) {
var towers []*Tower
err := c.db.View(func(tx *bbolt.Tx) error {
towerBucket := tx.Bucket(cTowerBkt)
err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
towerBucket := tx.ReadBucket(cTowerBkt)
if towerBucket == nil {
return ErrUninitializedDB
}
@@ -461,8 +461,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
// CreateClientSession is invoked should return the same index.
func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
var index uint32
err := c.db.Update(func(tx *bbolt.Tx) error {
keyIndex := tx.Bucket(cSessionKeyIndexBkt)
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
keyIndex := tx.ReadWriteBucket(cSessionKeyIndexBkt)
if keyIndex == nil {
return ErrUninitializedDB
}
@@ -509,20 +509,20 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (c *ClientDB) CreateClientSession(session *ClientSession) error {
return c.db.Update(func(tx *bbolt.Tx) error {
keyIndexes := tx.Bucket(cSessionKeyIndexBkt)
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
keyIndexes := tx.ReadWriteBucket(cSessionKeyIndexBkt)
if keyIndexes == nil {
return ErrUninitializedDB
}
sessions := tx.Bucket(cSessionBkt)
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
// Check that client session with this session id doesn't
// already exist.
existingSessionBytes := sessions.Bucket(session.ID[:])
existingSessionBytes := sessions.NestedReadWriteBucket(session.ID[:])
if existingSessionBytes != nil {
return ErrClientSessionAlreadyExists
}
@@ -558,8 +558,8 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession
err := c.db.View(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt)
err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
@@ -577,7 +577,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func listClientSessions(sessions *bbolt.Bucket,
func listClientSessions(sessions kvdb.ReadBucket,
id *TowerID) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession)
@@ -612,8 +612,8 @@ func listClientSessions(sessions *bbolt.Bucket,
// channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary)
err := c.db.View(func(tx *bbolt.Tx) error {
chanSummaries := tx.Bucket(cChanSummaryBkt)
err := kvdb.View(c.db, func(tx kvdb.ReadTx) error {
chanSummaries := tx.ReadBucket(cChanSummaryBkt)
if chanSummaries == nil {
return ErrUninitializedDB
}
@@ -648,8 +648,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error {
return c.db.Update(func(tx *bbolt.Tx) error {
chanSummaries := tx.Bucket(cChanSummaryBkt)
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
chanSummaries := tx.ReadWriteBucket(cChanSummaryBkt)
if chanSummaries == nil {
return ErrUninitializedDB
}
@@ -692,8 +692,8 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
update *CommittedUpdate) (uint16, error) {
var lastApplied uint16
err := c.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt)
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
@@ -708,7 +708,7 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
}
// Can't fail if the above didn't fail.
sessionBkt := sessions.Bucket(id[:])
sessionBkt := sessions.NestedReadWriteBucket(id[:])
// Ensure the session commits sub-bucket is initialized.
sessionCommits, err := sessionBkt.CreateBucketIfNotExists(
@@ -796,8 +796,8 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
lastApplied uint16) error {
return c.db.Update(func(tx *bbolt.Tx) error {
sessions := tx.Bucket(cSessionBkt)
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
sessions := tx.ReadWriteBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
@@ -835,11 +835,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
}
// Can't fail because of getClientSession succeeded.
sessionBkt := sessions.Bucket(id[:])
sessionBkt := sessions.NestedReadWriteBucket(id[:])
// If the commits sub-bucket doesn't exist, there can't possibly
// be a corresponding committed update to remove.
sessionCommits := sessionBkt.Bucket(cSessionCommits)
sessionCommits := sessionBkt.NestedReadWriteBucket(cSessionCommits)
if sessionCommits == nil {
return ErrCommittedUpdateNotFound
}
@@ -894,10 +894,10 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// bucket corresponding to the serialized session id. This does not deserialize
// the CommittedUpdates or AckUpdates associated with the session. If the caller
// requires this info, use getClientSession.
func getClientSessionBody(sessions *bbolt.Bucket,
func getClientSessionBody(sessions kvdb.ReadBucket,
idBytes []byte) (*ClientSession, error) {
sessionBkt := sessions.Bucket(idBytes)
sessionBkt := sessions.NestedReadBucket(idBytes)
if sessionBkt == nil {
return nil, ErrClientSessionNotFound
}
@@ -922,7 +922,7 @@ func getClientSessionBody(sessions *bbolt.Bucket,
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates and AckUpdates in
// addition to the ClientSession's body.
func getClientSession(sessions *bbolt.Bucket,
func getClientSession(sessions kvdb.ReadBucket,
idBytes []byte) (*ClientSession, error) {
session, err := getClientSessionBody(sessions, idBytes)
@@ -950,17 +950,17 @@ func getClientSession(sessions *bbolt.Bucket,
// getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id.
func getClientSessionCommits(sessions *bbolt.Bucket,
func getClientSessionCommits(sessions kvdb.ReadBucket,
idBytes []byte) ([]CommittedUpdate, error) {
// Can't fail because client session body has already been read.
sessionBkt := sessions.Bucket(idBytes)
sessionBkt := sessions.NestedReadBucket(idBytes)
// Initialize commitedUpdates so that we can return an initialized map
// if no committed updates exist.
committedUpdates := make([]CommittedUpdate, 0)
sessionCommits := sessionBkt.Bucket(cSessionCommits)
sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
if sessionCommits == nil {
return committedUpdates, nil
}
@@ -986,17 +986,17 @@ func getClientSessionCommits(sessions *bbolt.Bucket,
// getClientSessionAcks retrieves all acked updates for the session identified
// by the serialized session id.
func getClientSessionAcks(sessions *bbolt.Bucket,
func getClientSessionAcks(sessions kvdb.ReadBucket,
idBytes []byte) (map[uint16]BackupID, error) {
// Can't fail because client session body has already been read.
sessionBkt := sessions.Bucket(idBytes)
sessionBkt := sessions.NestedReadBucket(idBytes)
// Initialize ackedUpdates so that we can return an initialized map if
// no acked updates exist.
ackedUpdates := make(map[uint16]BackupID)
sessionAcks := sessionBkt.Bucket(cSessionAcks)
sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
if sessionAcks == nil {
return ackedUpdates, nil
}
@@ -1023,7 +1023,7 @@ func getClientSessionAcks(sessions *bbolt.Bucket,
// putClientSessionBody stores the body of the ClientSession (everything but the
// CommittedUpdates and AckedUpdates).
func putClientSessionBody(sessions *bbolt.Bucket,
func putClientSessionBody(sessions kvdb.RwBucket,
session *ClientSession) error {
sessionBkt, err := sessions.CreateBucketIfNotExists(session.ID[:])
@@ -1042,7 +1042,7 @@ func putClientSessionBody(sessions *bbolt.Bucket,
// markSessionStatus updates the persisted state of the session to the new
// status.
func markSessionStatus(sessions *bbolt.Bucket, session *ClientSession,
func markSessionStatus(sessions kvdb.RwBucket, session *ClientSession,
status CSessionStatus) error {
session.Status = status
@@ -1050,7 +1050,7 @@ func markSessionStatus(sessions *bbolt.Bucket, session *ClientSession,
}
// getChanSummary loads a ClientChanSummary for the passed chanID.
func getChanSummary(chanSummaries *bbolt.Bucket,
func getChanSummary(chanSummaries kvdb.ReadBucket,
chanID lnwire.ChannelID) (*ClientChanSummary, error) {
chanSummaryBytes := chanSummaries.Get(chanID[:])
@@ -1068,7 +1068,7 @@ func getChanSummary(chanSummaries *bbolt.Bucket,
}
// putChanSummary stores a ClientChanSummary for the passed chanID.
func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID,
func putChanSummary(chanSummaries kvdb.RwBucket, chanID lnwire.ChannelID,
summary *ClientChanSummary) error {
var b bytes.Buffer
@@ -1081,7 +1081,7 @@ func putChanSummary(chanSummaries *bbolt.Bucket, chanID lnwire.ChannelID,
}
// getTower loads a Tower identified by its serialized tower id.
func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) {
func getTower(towers kvdb.ReadBucket, id []byte) (*Tower, error) {
towerBytes := towers.Get(id)
if towerBytes == nil {
return nil, ErrTowerNotFound
@@ -1099,7 +1099,7 @@ func getTower(towers *bbolt.Bucket, id []byte) (*Tower, error) {
}
// putTower stores a Tower identified by its serialized tower id.
func putTower(towers *bbolt.Bucket, tower *Tower) error {
func putTower(towers kvdb.RwBucket, tower *Tower) error {
var b bytes.Buffer
err := tower.Encode(&b)
if err != nil {