diff --git a/channeldb/db.go b/channeldb/db.go index a9aa0d124..8b373d8d7 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -209,6 +209,11 @@ var ( // Big endian is the preferred byte order, due to cursor scans over // integer keys iterating in order. byteOrder = binary.BigEndian + + // channelOpeningStateBucket is the database bucket used to store the + // channelOpeningState for each channel that is currently in the process + // of being opened. + channelOpeningStateBucket = []byte("channelOpeningState") ) // DB is the primary datastore for the lnd daemon. The database stores @@ -1197,6 +1202,56 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator) } +// SaveChannelOpeningState saves the serialized channel state for the provided +// chanPoint to the channelOpeningStateBucket. +func (d *DB) SaveChannelOpeningState(outPoint, serializedState []byte) error { + return kvdb.Update(d, func(tx kvdb.RwTx) error { + bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) + if err != nil { + return err + } + + return bucket.Put(outPoint, serializedState) + }, func() {}) +} + +// GetChannelOpeningState fetches the serialized channel state for the provided +// outPoint from the database, or returns ErrChannelNotFound if the channel +// is not found. +func (d *DB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { + var serializedState []byte + err := kvdb.View(d, func(tx kvdb.RTx) error { + bucket := tx.ReadBucket(channelOpeningStateBucket) + if bucket == nil { + // If the bucket does not exist, it means we never added + // a channel to the db, so return ErrChannelNotFound. + return ErrChannelNotFound + } + + serializedState = bucket.Get(outPoint) + if serializedState == nil { + return ErrChannelNotFound + } + + return nil + }, func() { + serializedState = nil + }) + return serializedState, err +} + +// DeleteChannelOpeningState removes any state for outPoint from the database. +func (d *DB) DeleteChannelOpeningState(outPoint []byte) error { + return kvdb.Update(d, func(tx kvdb.RwTx) error { + bucket := tx.ReadWriteBucket(channelOpeningStateBucket) + if bucket == nil { + return ErrChannelNotFound + } + + return bucket.Delete(outPoint) + }, func() {}) +} + // syncVersions function is used for safe db version synchronization. It // applies migration functions to the current database and recovers the // previous state of db if at least one error/panic appeared during migration. diff --git a/funding/manager.go b/funding/manager.go index f60039044..3daa393f1 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -23,7 +23,6 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnrpc" @@ -550,19 +549,6 @@ const ( addedToRouterGraph ) -var ( - // channelOpeningStateBucket is the database bucket used to store the - // channelOpeningState for each channel that is currently in the process - // of being opened. - channelOpeningStateBucket = []byte("channelOpeningState") - - // ErrChannelNotFound is an error returned when a channel is not known - // to us. In this case of the fundingManager, this error is returned - // when the channel in question is not considered being in an opening - // state. - ErrChannelNotFound = fmt.Errorf("channel not found") -) - // NewFundingManager creates and initializes a new instance of the // fundingManager. func NewFundingManager(cfg Config) (*Manager, error) { @@ -887,7 +873,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, channelState, shortChanID, err := f.getChannelOpeningState( &channel.FundingOutpoint, ) - if err == ErrChannelNotFound { + if err == channeldb.ErrChannelNotFound { // Channel not in fundingManager's opening database, // meaning it was successfully announced to the // network. @@ -3539,26 +3525,20 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey { // chanPoint to the channelOpeningStateBucket. func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, state channelOpeningState, shortChanID *lnwire.ShortChannelID) error { - return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { - bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) - if err != nil { - return err - } + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return err + } - var outpointBytes bytes.Buffer - if err = WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - // Save state and the uint64 representation of the shortChanID - // for later use. - scratch := make([]byte, 10) - byteOrder.PutUint16(scratch[:2], uint16(state)) - byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) - - return bucket.Put(outpointBytes.Bytes(), scratch) - }, func() {}) + // Save state and the uint64 representation of the shortChanID + // for later use. + scratch := make([]byte, 10) + byteOrder.PutUint16(scratch[:2], uint16(state)) + byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) + return f.cfg.Wallet.Cfg.Database.SaveChannelOpeningState( + outpointBytes.Bytes(), scratch, + ) } // getChannelOpeningState fetches the channelOpeningState for the provided @@ -3567,51 +3547,31 @@ func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, func (f *Manager) getChannelOpeningState(chanPoint *wire.OutPoint) ( channelOpeningState, *lnwire.ShortChannelID, error) { - var state channelOpeningState - var shortChanID lnwire.ShortChannelID - err := kvdb.View(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RTx) error { + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return 0, nil, err + } - bucket := tx.ReadBucket(channelOpeningStateBucket) - if bucket == nil { - // If the bucket does not exist, it means we never added - // a channel to the db, so return ErrChannelNotFound. - return ErrChannelNotFound - } - - var outpointBytes bytes.Buffer - if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - value := bucket.Get(outpointBytes.Bytes()) - if value == nil { - return ErrChannelNotFound - } - - state = channelOpeningState(byteOrder.Uint16(value[:2])) - shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) - return nil - }, func() {}) + value, err := f.cfg.Wallet.Cfg.Database.GetChannelOpeningState( + outpointBytes.Bytes(), + ) if err != nil { return 0, nil, err } + state := channelOpeningState(byteOrder.Uint16(value[:2])) + shortChanID := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) return state, &shortChanID, nil } // deleteChannelOpeningState removes any state for chanPoint from the database. func (f *Manager) deleteChannelOpeningState(chanPoint *wire.OutPoint) error { - return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { - bucket := tx.ReadWriteBucket(channelOpeningStateBucket) - if bucket == nil { - return fmt.Errorf("bucket not found") - } + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return err + } - var outpointBytes bytes.Buffer - if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - return bucket.Delete(outpointBytes.Bytes()) - }, func() {}) + return f.cfg.Wallet.Cfg.Database.DeleteChannelOpeningState( + outpointBytes.Bytes(), + ) } diff --git a/funding/manager_test.go b/funding/manager_test.go index 97ee699f2..acd7ca514 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -922,12 +922,12 @@ func assertDatabaseState(t *testing.T, node *testNode, } state, _, err = node.fundingMgr.getChannelOpeningState( fundingOutPoint) - if err != nil && err != ErrChannelNotFound { + if err != nil && err != channeldb.ErrChannelNotFound { t.Fatalf("unable to get channel state: %v", err) } // If we found the channel, check if it had the expected state. - if err != ErrChannelNotFound && state == expectedState { + if err != channeldb.ErrChannelNotFound && state == expectedState { // Got expected state, return with success. return } @@ -1165,7 +1165,7 @@ func assertErrChannelNotFound(t *testing.T, node *testNode, } state, _, err = node.fundingMgr.getChannelOpeningState( fundingOutPoint) - if err == ErrChannelNotFound { + if err == channeldb.ErrChannelNotFound { // Got expected state, return with success. return } else if err != nil {