diff --git a/watchtower/wtdb/db_common.go b/watchtower/wtdb/db_common.go new file mode 100644 index 000000000..63bca9ae7 --- /dev/null +++ b/watchtower/wtdb/db_common.go @@ -0,0 +1,92 @@ +package wtdb + +import ( + "encoding/binary" + "errors" + "os" + "path/filepath" + + "github.com/coreos/bbolt" +) + +const ( + // dbFilePermission requests read+write access to the db file. + dbFilePermission = 0600 +) + +var ( + // metadataBkt stores all the meta information concerning the state of + // the database. + metadataBkt = []byte("metadata-bucket") + + // dbVersionKey is a static key used to retrieve the database version + // number from the metadataBkt. + dbVersionKey = []byte("version") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrNoDBVersion signals that the database contains no version info. + ErrNoDBVersion = errors.New("db has no version") + + // byteOrder is the default endianness used when serializing integers. + byteOrder = binary.BigEndian +) + +// fileExists returns true if the file exists, and false otherwise. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} + +// createDBIfNotExist opens the boltdb database at dbPath/name, creating one if +// one doesn't exist. The boolean returned indicates if the database did not +// exist before, or if it has been created but no version metadata exists within +// it. +func createDBIfNotExist(dbPath, name string) (*bbolt.DB, bool, error) { + path := filepath.Join(dbPath, name) + + // If the database file doesn't exist, this indicates we much initialize + // a fresh database with the latest version. + firstInit := !fileExists(path) + if firstInit { + // Ensure all parent directories are initialized. + err := os.MkdirAll(dbPath, 0700) + if err != nil { + return nil, false, err + } + } + + bdb, err := bbolt.Open(path, dbFilePermission, nil) + if err != nil { + return nil, false, err + } + + // If the file existed previously, we'll now check to see that the + // metadata bucket is properly initialized. It could be the case that + // the database was created, but we failed to actually populate any + // metadata. If the metadata bucket does not actually exist, we'll + // set firstInit to true so that we can treat is initialize the bucket. + if !firstInit { + var metadataExists bool + err = bdb.View(func(tx *bbolt.Tx) error { + metadataExists = tx.Bucket(metadataBkt) != nil + return nil + }) + if err != nil { + return nil, false, err + } + + if !metadataExists { + firstInit = true + } + } + + return bdb, firstInit, nil +} diff --git a/watchtower/wtdb/tower_db.go b/watchtower/wtdb/tower_db.go index 0bcd271c0..96edafcad 100644 --- a/watchtower/wtdb/tower_db.go +++ b/watchtower/wtdb/tower_db.go @@ -2,23 +2,16 @@ package wtdb import ( "bytes" - "encoding/binary" "errors" - "os" - "path/filepath" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" ) const ( - // dbName is the filename of tower database. - dbName = "watchtower.db" - - // dbFilePermission requests read+write access to the db file. - dbFilePermission = 0600 + // towerDBName is the filename of tower database. + towerDBName = "watchtower.db" ) var ( @@ -49,26 +42,9 @@ var ( // epoch from the lookoutTipBkt. lookoutTipKey = []byte("lookout-tip") - // metadataBkt stores all the meta information concerning the state of - // the database. - metadataBkt = []byte("metadata-bucket") - - // dbVersionKey is a static key used to retrieve the database version - // number from the metadataBkt. - dbVersionKey = []byte("version") - - // ErrUninitializedDB signals that top-level buckets for the database - // have not been initialized. - ErrUninitializedDB = errors.New("tower db not initialized") - - // ErrNoDBVersion signals that the database contains no version info. - ErrNoDBVersion = errors.New("tower db has no version") - // ErrNoSessionHintIndex signals that an active session does not have an // initialized index for tracking its own state updates. ErrNoSessionHintIndex = errors.New("session hint index missing") - - byteOrder = binary.BigEndian ) // TowerDB is single database providing a persistent storage engine for the @@ -86,67 +62,20 @@ type TowerDB struct { // with a version number higher that the latest version will fail to prevent // accidental reversion. func OpenTowerDB(dbPath string) (*TowerDB, error) { - path := filepath.Join(dbPath, dbName) - - // If the database file doesn't exist, this indicates we much initialize - // a fresh database with the latest version. - firstInit := !fileExists(path) - if firstInit { - // Ensure all parent directories are initialized. - err := os.MkdirAll(dbPath, 0700) - if err != nil { - return nil, err - } - } - - bdb, err := bbolt.Open(path, dbFilePermission, nil) + bdb, firstInit, err := createDBIfNotExist(dbPath, towerDBName) if err != nil { return nil, err } - // If the file existed previously, we'll now check to see that the - // metadata bucket is properly initialized. It could be the case that - // the database was created, but we failed to actually populate any - // metadata. If the metadata bucket does not actually exist, we'll - // set firstInit to true so that we can treat is initialize the bucket. - if !firstInit { - var metadataExists bool - err = bdb.View(func(tx *bbolt.Tx) error { - metadataExists = tx.Bucket(metadataBkt) != nil - return nil - }) - if err != nil { - return nil, err - } - - if !metadataExists { - firstInit = true - } - } - towerDB := &TowerDB{ db: bdb, dbPath: dbPath, } - if firstInit { - // If the database has not yet been created, we'll initialize - // the database version with the latest known version. - err = towerDB.db.Update(func(tx *bbolt.Tx) error { - return initDBVersion(tx, getLatestDBVersion(dbVersions)) - }) - if err != nil { - bdb.Close() - return nil, err - } - } else { - // Otherwise, ensure that any migrations are applied to ensure - // the data is in the format expected by the latest version. - err = towerDB.syncVersions(dbVersions) - if err != nil { - bdb.Close() - return nil, err - } + err = initOrSyncVersions(towerDB, firstInit, towerDBVersions) + if err != nil { + bdb.Close() + return nil, err } // Now that the database version fully consistent with our latest known @@ -163,17 +92,6 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) { return towerDB, nil } -// fileExists returns true if the file exists, and false otherwise. -func fileExists(path string) bool { - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - return false - } - } - - return true -} - // initTowerDBBuckets creates all top-level buckets required to handle database // operations required by the latest version. func initTowerDBBuckets(tx *bbolt.Tx) error { @@ -194,53 +112,16 @@ func initTowerDBBuckets(tx *bbolt.Tx) error { return nil } -// syncVersions ensures the database version is consistent with the highest -// known database version, applying any migrations that have not been made. If -// the highest known version number is lower than the database's version, this -// method will fail to prevent accidental reversions. -func (t *TowerDB) syncVersions(versions []version) error { - curVersion, err := t.Version() - if err != nil { - return err - } - - latestVersion := getLatestDBVersion(versions) - switch { - - // Current version is higher than any known version, fail to prevent - // reversion. - case curVersion > latestVersion: - return channeldb.ErrDBReversion - - // Current version matches highest known version, nothing to do. - case curVersion == latestVersion: - return nil - } - - // Otherwise, apply any migrations in order to bring the database - // version up to the highest known version. - updates := getMigrations(versions, curVersion) - return t.db.Update(func(tx *bbolt.Tx) error { - for _, update := range updates { - if update.migration == nil { - continue - } - - log.Infof("Applying migration #%d", update.number) - - err := update.migration(tx) - if err != nil { - log.Errorf("Unable to apply migration #%d: %v", - err) - return err - } - } - - return putDBVersion(tx, latestVersion) - }) +// bdb returns the backing bbolt.DB instance. +// +// NOTE: Part of the versionedDB interface. +func (t *TowerDB) bdb() *bbolt.DB { + return t.db } // Version returns the database's current version number. +// +// NOTE: Part of the versionedDB interface. func (t *TowerDB) Version() (uint32, error) { var version uint32 err := t.db.View(func(tx *bbolt.Tx) error { diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index fd7481aff..974f25b06 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -1,6 +1,9 @@ package wtdb -import "github.com/coreos/bbolt" +import ( + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" +) // migration is a function which takes a prior outdated version of the database // instances and mutates the key/bucket structure to arrive at a more @@ -10,32 +13,25 @@ type migration func(tx *bbolt.Tx) error // version pairs a version number with the migration that would need to be // applied from the prior version to upgrade. type version struct { - number uint32 migration migration } -// dbVersions stores all versions and migrations of the database. This list will -// be used when opening the database to determine if any migrations must be -// applied. -var dbVersions = []version{ - { - // Initial version requires no migration. - number: 0, - migration: nil, - }, -} +// towerDBVersions stores all versions and migrations of the tower database. +// This list will be used when opening the database to determine if any +// migrations must be applied. +var towerDBVersions = []version{} // getLatestDBVersion returns the last known database version. func getLatestDBVersion(versions []version) uint32 { - return versions[len(versions)-1].number + return uint32(len(versions)) } // getMigrations returns a slice of all updates with a greater number that // curVersion that need to be applied to sync up with the latest version. func getMigrations(versions []version, curVersion uint32) []version { var updates []version - for _, v := range versions { - if v.number > curVersion { + for i, v := range versions { + if uint32(i)+1 > curVersion { updates = append(updates, v) } } @@ -82,3 +78,81 @@ func putDBVersion(tx *bbolt.Tx, version uint32) error { byteOrder.PutUint32(versionBytes, version) return metadata.Put(dbVersionKey, versionBytes) } + +// versionedDB is a private interface implemented by both the tower and client +// databases, permitting all versioning operations to be performed generically +// on either. +type versionedDB interface { + // bdb returns the underlying bbolt database. + bdb() *bbolt.DB + + // Version returns the current version stored in the database. + Version() (uint32, error) +} + +// initOrSyncVersions ensures that the database version is properly set before +// opening the database up for regular use. When the database is being +// initialized for the first time, the caller should set init to true, which +// will simply write the latest version to the database. Otherwise, passing init +// as false will cause the database to apply any needed migrations to ensure its +// version matches the latest version in the provided versions list. +func initOrSyncVersions(db versionedDB, init bool, versions []version) error { + // If the database has not yet been created, we'll initialize the + // database version with the latest known version. + if init { + return db.bdb().Update(func(tx *bbolt.Tx) error { + return initDBVersion(tx, getLatestDBVersion(versions)) + }) + } + + // Otherwise, ensure that any migrations are applied to ensure the data + // is in the format expected by the latest version. + return syncVersions(db, versions) +} + +// syncVersions ensures the database version is consistent with the highest +// known database version, applying any migrations that have not been made. If +// the highest known version number is lower than the database's version, this +// method will fail to prevent accidental reversions. +func syncVersions(db versionedDB, versions []version) error { + curVersion, err := db.Version() + if err != nil { + return err + } + + latestVersion := getLatestDBVersion(versions) + switch { + + // Current version is higher than any known version, fail to prevent + // reversion. + case curVersion > latestVersion: + return channeldb.ErrDBReversion + + // Current version matches highest known version, nothing to do. + case curVersion == latestVersion: + return nil + } + + // Otherwise, apply any migrations in order to bring the database + // version up to the highest known version. + updates := getMigrations(versions, curVersion) + return db.bdb().Update(func(tx *bbolt.Tx) error { + for i, update := range updates { + if update.migration == nil { + continue + } + + version := curVersion + uint32(i) + 1 + log.Infof("Applying migration #%d", version) + + err := update.migration(tx) + if err != nil { + log.Errorf("Unable to apply migration #%d: %v", + version, err) + return err + } + } + + return putDBVersion(tx, latestVersion) + }) +}