diff --git a/autopilot/graph.go b/autopilot/graph.go index 2624aa79d..e630f8d35 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -148,7 +148,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, return nil, err } - dbNode, err := d.db.FetchLightningNode(nil, vertex) + dbNode, err := d.db.FetchLightningNode(vertex) switch { case err == channeldb.ErrGraphNodeNotFound: fallthrough diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 035bdf647..c7ad2a8d5 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -75,7 +75,7 @@ type Config struct { // ChanStateDB is a pointer to the database that stores the channel // state. - ChanStateDB *channeldb.DB + ChanStateDB *channeldb.ChannelStateDB // BlockCacheSize is the size (in bytes) of blocks kept in memory. BlockCacheSize uint64 diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 076e2fcf6..dce1210b2 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -21,7 +21,11 @@ type LiveChannelSource interface { // passed chanPoint. Optionally an existing db tx can be supplied. FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( *channeldb.OpenChannel, error) +} +// AddressSource is an interface that allows us to query for the set of +// addresses a node can be connected to. +type AddressSource interface { // AddrsForNode returns all known addresses for the target node public // key. AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) @@ -31,15 +35,15 @@ type LiveChannelSource interface { // passed open channel. The backup includes all information required to restore // the channel, as well as addressing information so we can find the peer and // reconnect to them to initiate the protocol. -func assembleChanBackup(chanSource LiveChannelSource, +func assembleChanBackup(addrSource AddressSource, openChan *channeldb.OpenChannel) (*Single, error) { log.Debugf("Crafting backup for ChannelPoint(%v)", openChan.FundingOutpoint) // First, we'll query the channel source to obtain all the addresses - // that are are associated with the peer for this channel. - nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub) + // that are associated with the peer for this channel. + nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) if err != nil { return nil, err } @@ -52,8 +56,8 @@ func assembleChanBackup(chanSource LiveChannelSource, // FetchBackupForChan attempts to create a plaintext static channel backup for // the target channel identified by its channel point. If we're unable to find // the target channel, then an error will be returned. -func FetchBackupForChan(chanPoint wire.OutPoint, - chanSource LiveChannelSource) (*Single, error) { +func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, + addrSource AddressSource) (*Single, error) { // First, we'll query the channel source to see if the channel is known // and open within the database. @@ -66,7 +70,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, // Once we have the target channel, we can assemble the backup using // the source to obtain any extra information that we may need. 
- staticChanBackup, err := assembleChanBackup(chanSource, targetChan) + staticChanBackup, err := assembleChanBackup(addrSource, targetChan) if err != nil { return nil, fmt.Errorf("unable to create chan backup: %v", err) } @@ -76,7 +80,9 @@ func FetchBackupForChan(chanPoint wire.OutPoint, // FetchStaticChanBackups will return a plaintext static channel back up for // all known active/open channels within the passed channel source. -func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { +func FetchStaticChanBackups(chanSource LiveChannelSource, + addrSource AddressSource) ([]Single, error) { + // First, we'll query the backup source for information concerning all // currently open and available channels. openChans, err := chanSource.FetchAllChannels() @@ -89,7 +95,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { // channel. staticChanBackups := make([]Single, 0, len(openChans)) for _, openChan := range openChans { - chanBackup, err := assembleChanBackup(chanSource, openChan) + chanBackup, err := assembleChanBackup(addrSource, openChan) if err != nil { return nil, err } diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index e718dce3e..ff321c188 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -124,7 +124,9 @@ func TestFetchBackupForChan(t *testing.T) { }, } for i, testCase := range testCases { - _, err := FetchBackupForChan(testCase.chanPoint, chanSource) + _, err := FetchBackupForChan( + testCase.chanPoint, chanSource, chanSource, + ) switch { // If this is a valid test case, and we failed, then we'll // return an error. @@ -167,7 +169,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // With the channel source populated, we'll now attempt to create a set // of backups for all the channels. This should succeed, as all items // are populated within the channel source. - backups, err := FetchStaticChanBackups(chanSource) + backups, err := FetchStaticChanBackups(chanSource, chanSource) if err != nil { t.Fatalf("unable to create chan back ups: %v", err) } @@ -184,7 +186,7 @@ func TestFetchStaticChanBackups(t *testing.T) { copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) delete(chanSource.addrs, n) - _, err = FetchStaticChanBackups(chanSource) + _, err = FetchStaticChanBackups(chanSource, chanSource) if err == nil { t.Fatalf("query with incomplete information should fail") } @@ -193,7 +195,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // source at all, then we'll fail as well. chanSource = newMockChannelSource() chanSource.failQuery = true - _, err = FetchStaticChanBackups(chanSource) + _, err = FetchStaticChanBackups(chanSource, chanSource) if err == nil { t.Fatalf("query should fail") } diff --git a/channeldb/channel.go b/channeldb/channel.go index 31873ae3f..858cead91 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -729,7 +729,7 @@ type OpenChannel struct { RevocationKeyLocator keychain.KeyLocator // TODO(roasbeef): eww - Db *DB + Db *ChannelStateDB // TODO(roasbeef): just need to store local and remote HTLC's? 
@@ -800,7 +800,7 @@ func (c *OpenChannel) RefreshShortChanID() error { c.Lock() defer c.Unlock() - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -875,12 +875,43 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) { - readBucket, err := fetchChanBucket(tx, nodeKey, outPoint, chainHash) - if err != nil { - return nil, err + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket := tx.ReadWriteBucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists } - return readBucket.(kvdb.RwBucket), nil + // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like + // CreateIfNotExists, will return error + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil } // fullSync syncs the contents of an OpenChannel while re-using an existing @@ -964,8 +995,8 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { c.Lock() defer c.Unlock() - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucket( + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { + chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -980,7 +1011,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { channel.IsPending = false channel.ShortChannelID = openLoc - return putOpenChannel(chanBucket.(kvdb.RwBucket), channel) + return putOpenChannel(chanBucket, channel) }, func() {}); err != nil { return err } @@ -1016,7 +1047,7 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { var commitPoint *btcec.PublicKey - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1240,7 +1271,7 @@ func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) { func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { var closeTx *wire.MsgTx - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1274,7 +1305,7 @@ func (c 
*OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { func (c *OpenChannel) putChanStatus(status ChannelStatus, fs ...func(kvdb.RwBucket) error) error { - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1318,7 +1349,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus, } func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1442,7 +1473,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { c.FundingBroadcastHeight = pendingHeight - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return syncNewChannel(tx, c, []net.Addr{addr}) }, func() {}) } @@ -1470,7 +1501,10 @@ func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error { // Next, we need to establish a (possibly) new LinkNode relationship // for this channel. The LinkNode metadata contains reachability, // up-time, and service bits related information. - linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) + linkNode := NewLinkNode( + &LinkNodeDB{backend: c.Db.backend}, + wire.MainNet, c.IdentityPub, addrs..., + ) // TODO(roasbeef): do away with link node all together? @@ -1498,7 +1532,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment, return ErrNoRestoredChannelMutation } - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2090,7 +2124,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { return ErrNoRestoredChannelMutation } - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { // First, we'll grab the writable bucket where this channel's // data resides. chanBucket, err := fetchChanBucketRw( @@ -2160,7 +2194,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { // these pointers, causing the tip and the tail to point to the same entry. func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { var cd *CommitDiff - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2199,7 +2233,7 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { // updates that still need to be signed for. func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { var updates []LogUpdate - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2233,7 +2267,7 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { // updates that the remote still needs to sign for. 
func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { var updates []LogUpdate - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2277,7 +2311,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { c.RemoteNextRevocation = revKey - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2318,7 +2352,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, var newRemoteCommit *ChannelCommitment - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2493,7 +2527,7 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { defer c.RUnlock() var fwdPkgs []*FwdPkg - if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { var err error fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) return err @@ -2513,7 +2547,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.AckAddHtlcs(tx, addRefs...) }, func() {}) } @@ -2526,7 +2560,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.AckSettleFails(tx, settleFailRefs...) }, func() {}) } @@ -2537,7 +2571,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.SetFwdFilter(tx, height, fwdFilter) }, func() {}) } @@ -2551,7 +2585,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { for _, height := range heights { err := c.Packager.RemovePkg(tx, height) if err != nil { @@ -2579,7 +2613,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { } var commit ChannelCommitment - if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2626,7 +2660,7 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) { defer c.RUnlock() var height uint64 - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { // Get the bucket dedicated to storing the metadata for open // channels. 
chanBucket, err := fetchChanBucket( @@ -2663,7 +2697,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e defer c.RUnlock() var commit ChannelCommitment - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2821,7 +2855,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { openChanBucket := tx.ReadWriteBucket(openChannelBucket) if openChanBucket == nil { return ErrNoChanDBExists @@ -3033,7 +3067,7 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { // latest fully committed state is returned. The first commitment returned is // the local commitment, and the second returned is the remote commitment. func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -3055,7 +3089,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen // acting on a possible contract breach to ensure, that the caller has the most // up to date information required to deliver justice. func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index ad1b3c07c..044308f88 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -183,7 +183,7 @@ var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { // createTestChannel writes a test channel to the database. It takes a set of // functional options which can be used to overwrite the default of creating // a pending channel that was broadcast at height 100. -func createTestChannel(t *testing.T, cdb *DB, +func createTestChannel(t *testing.T, cdb *ChannelStateDB, opts ...testChannelOption) *OpenChannel { // Create a default set of parameters. @@ -221,7 +221,7 @@ func createTestChannel(t *testing.T, cdb *DB, return params.channel } -func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { +func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { // Simulate 1000 channel updates. producer, err := shachain.NewRevocationProducerFromBytes(key[:]) if err != nil { @@ -359,12 +359,14 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create the test channel state, with additional htlcs on the local // and remote commitment. 
localHtlcs := []HTLC{ @@ -508,12 +510,14 @@ func TestOptionalShutdown(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a channel with upfront scripts set as // specified in the test. state := createTestChannel( @@ -565,12 +569,14 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func TestChannelStateTransition(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First create a minimal channel, then perform a full sync in order to // persist the data. channel := createTestChannel(t, cdb) @@ -842,7 +848,7 @@ func TestChannelStateTransition(t *testing.T) { } // At this point, we should have 2 forwarding packages added. - fwdPkgs := loadFwdPkgs(t, cdb, channel.Packager) + fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager) require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") // Now attempt to delete the channel from the database. @@ -877,19 +883,21 @@ func TestChannelStateTransition(t *testing.T) { } // All forwarding packages of this channel has been deleted too. - fwdPkgs = loadFwdPkgs(t, cdb, channel.Packager) + fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager) require.Empty(t, fwdPkgs, "no forwarding packages should exist") } func TestFetchPendingChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a pending channel that was broadcast at height 99. const broadcastHeight = 99 createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) @@ -963,12 +971,14 @@ func TestFetchPendingChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel in the database. state := createTestChannel(t, cdb, openChannelOption()) @@ -1054,18 +1064,20 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // We'll start by creating two channels within our test database. One of // them will have their funding transaction confirmed on-chain, while // the other one will remain unconfirmed. - db, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + channels := make([]*OpenChannel, numChannels) for i := 0; i < numChannels; i++ { // Create a pending channel in the database at the broadcast // height. channels[i] = createTestChannel( - t, db, pendingHeightOption(broadcastHeight), + t, cdb, pendingHeightOption(broadcastHeight), ) } @@ -1116,7 +1128,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // Now, we'll fetch all the channels waiting to be closed from the // database. We should expect to see both channels above, even if any of // them haven't had their funding transaction confirm on-chain. 
- waitingCloseChannels, err := db.FetchWaitingCloseChannels() + waitingCloseChannels, err := cdb.FetchWaitingCloseChannels() if err != nil { t.Fatalf("unable to fetch all waiting close channels: %v", err) } @@ -1169,12 +1181,14 @@ func TestFetchWaitingCloseChannels(t *testing.T) { func TestRefreshShortChanID(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First create a test channel. state := createTestChannel(t, cdb) @@ -1317,13 +1331,15 @@ func TestCloseInitiator(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), @@ -1362,13 +1378,15 @@ func TestCloseInitiator(t *testing.T) { // TestCloseChannelStatus tests setting of a channel status on the historical // channel on channel close. func TestCloseChannelStatus(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), @@ -1427,7 +1445,7 @@ func TestBalanceAtHeight(t *testing.T) { putRevokedState := func(c *OpenChannel, height uint64, local, remote lnwire.MilliSatoshi) error { - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, @@ -1508,13 +1526,15 @@ func TestBalanceAtHeight(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create options to set the heights and balances of // our local and remote commitments. localCommitOpt := channelCommitmentOption( diff --git a/channeldb/db.go b/channeldb/db.go index 57ebfdb24..17275c292 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -23,6 +23,7 @@ import ( "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" ) const ( @@ -209,6 +210,11 @@ var ( // Big endian is the preferred byte order, due to cursor scans over // integer keys iterating in order. byteOrder = binary.BigEndian + + // channelOpeningStateBucket is the database bucket used to store the + // channelOpeningState for each channel that is currently in the process + // of being opened. + channelOpeningStateBucket = []byte("channelOpeningState") ) // DB is the primary datastore for the lnd daemon. The database stores @@ -217,6 +223,9 @@ var ( type DB struct { kvdb.Backend + // channelStateDB separates all DB operations on channel state. 
+ channelStateDB *ChannelStateDB + dbPath string graph *ChannelGraph clock clock.Clock @@ -265,13 +274,27 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, chanDB := &DB{ Backend: backend, - clock: opts.clock, - dryRun: opts.dryRun, + channelStateDB: &ChannelStateDB{ + linkNodeDB: &LinkNodeDB{ + backend: backend, + }, + backend: backend, + }, + clock: opts.clock, + dryRun: opts.dryRun, } - chanDB.graph = newChannelGraph( - chanDB, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, + + // Set the parent pointer (only used in tests). + chanDB.channelStateDB.parent = chanDB + + var err error + chanDB.graph, err = NewChannelGraph( + backend, opts.RejectCacheSize, opts.ChannelCacheSize, + opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, ) + if err != nil { + return nil, err + } // Synchronize the version of database and apply migrations if needed. if err := chanDB.syncVersions(dbVersions); err != nil { @@ -287,7 +310,7 @@ func (d *DB) Path() string { return d.dbPath } -var topLevelBuckets = [][]byte{ +var dbTopLevelBuckets = [][]byte{ openChannelBucket, closedChannelBucket, forwardingLogBucket, @@ -298,10 +321,6 @@ var topLevelBuckets = [][]byte{ paymentsIndexBucket, peersBucket, nodeInfoBucket, - nodeBucket, - edgeBucket, - edgeIndexBucket, - graphMetaBucket, metaBucket, closeSummaryBucket, outpointBucket, @@ -312,7 +331,7 @@ var topLevelBuckets = [][]byte{ // operation is fully atomic. func (d *DB) Wipe() error { err := kvdb.Update(d, func(tx kvdb.RwTx) error { - for _, tlb := range topLevelBuckets { + for _, tlb := range dbTopLevelBuckets { err := tx.DeleteTopLevelBucket(tlb) if err != nil && err != kvdb.ErrBucketNotFound { return err @@ -327,10 +346,10 @@ func (d *DB) Wipe() error { return initChannelDB(d.Backend) } -// createChannelDB creates and initializes a fresh version of channeldb. In -// the case that the target path has not yet been created or doesn't yet exist, -// then the path is created. Additionally, all required top-level buckets used -// within the database are created. +// initChannelDB creates and initializes a fresh version of channeldb. In the +// case that the target path has not yet been created or doesn't yet exist, then +// the path is created. Additionally, all required top-level buckets used within +// the database are created. 
func initChannelDB(db kvdb.Backend) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error { meta := &Meta{} @@ -340,42 +359,12 @@ func initChannelDB(db kvdb.Backend) error { return nil } - for _, tlb := range topLevelBuckets { + for _, tlb := range dbTopLevelBuckets { if _, err := tx.CreateTopLevelBucket(tlb); err != nil { return err } } - nodes := tx.ReadWriteBucket(nodeBucket) - _, err = nodes.CreateBucket(aliasIndexBucket) - if err != nil { - return err - } - _, err = nodes.CreateBucket(nodeUpdateIndexBucket) - if err != nil { - return err - } - - edges := tx.ReadWriteBucket(edgeBucket) - if _, err := edges.CreateBucket(edgeIndexBucket); err != nil { - return err - } - if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil { - return err - } - if _, err := edges.CreateBucket(channelPointBucket); err != nil { - return err - } - if _, err := edges.CreateBucket(zombieBucket); err != nil { - return err - } - - graphMeta := tx.ReadWriteBucket(graphMetaBucket) - _, err = graphMeta.CreateBucket(pruneLogBucket) - if err != nil { - return err - } - meta.DbVersionNumber = getLatestDBVersion(dbVersions) return putMeta(meta, tx) }, func() {}) @@ -397,15 +386,45 @@ func fileExists(path string) bool { return true } +// ChannelStateDB is a database that keeps track of all channel state. +type ChannelStateDB struct { + // linkNodeDB separates all DB operations on LinkNodes. + linkNodeDB *LinkNodeDB + + // parent holds a pointer to the "main" channeldb.DB object. This is + // only used for testing and should never be used in production code. + // For testing use the ChannelStateDB.GetParentDB() function to retrieve + // this pointer. + parent *DB + + // backend points to the actual backend holding the channel state + // database. This may be a real backend or a cache middleware. + backend kvdb.Backend +} + +// GetParentDB returns the "main" channeldb.DB object that is the owner of this +// ChannelStateDB instance. Use this function only in tests where passing around +// pointers makes testing less readable. Never to be used in production code! +func (c *ChannelStateDB) GetParentDB() *DB { + return c.parent +} + +// LinkNodeDB returns the current instance of the link node database. +func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB { + return c.linkNodeDB +} + // FetchOpenChannels starts a new database transaction and returns all stored // currently active/open channels associated with the target nodeID. In the case // that no active channels are known to have been created with this node, then a // zero-length slice is returned. -func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) ( + []*OpenChannel, error) { + var channels []*OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { var err error - channels, err = d.fetchOpenChannels(tx, nodeID) + channels, err = c.fetchOpenChannels(tx, nodeID) return err }, func() { channels = nil @@ -418,7 +437,7 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) // stored currently active/open channels associated with the target nodeID. In // the case that no active channels are known to have been created with this // node, then a zero-length slice is returned. 
-func (d *DB) fetchOpenChannels(tx kvdb.RTx, +func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx, nodeID *btcec.PublicKey) ([]*OpenChannel, error) { // Get the bucket dedicated to storing the metadata for open channels. @@ -454,7 +473,7 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx, // Finally, with both of the necessary buckets retrieved, fetch // all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) + nodeChannels, err := c.fetchNodeChannels(chainBucket) if err != nil { return fmt.Errorf("unable to read channel for "+ "chain_hash=%x, node_key=%x: %v", @@ -471,7 +490,8 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx, // fetchNodeChannels retrieves all active channels from the target chainBucket // which is under a node's dedicated channel bucket. This function is typically // used to fetch all the active channels related to a particular node. -func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) { +func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( + []*OpenChannel, error) { var channels []*OpenChannel @@ -497,7 +517,7 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) return fmt.Errorf("unable to read channel data for "+ "chan_point=%v: %v", outPoint, err) } - oChannel.Db = d + oChannel.Db = c channels = append(channels, oChannel) @@ -514,8 +534,8 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) // point. If the channel cannot be found, then an error will be returned. // Optionally an existing db tx can be supplied. -func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, - error) { +func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( + *OpenChannel, error) { var ( targetChan *OpenChannel @@ -591,7 +611,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, } targetChan = channel - targetChan.Db = d + targetChan.Db = c return nil }) @@ -600,7 +620,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, var err error if tx == nil { - err = kvdb.View(d, chanScan, func() {}) + err = kvdb.View(c.backend, chanScan, func() {}) } else { err = chanScan(tx) } @@ -620,16 +640,16 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, // FetchAllChannels attempts to retrieve all open channels currently stored // within the database, including pending open, fully open and channels waiting // for a closing transaction to confirm. -func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - return fetchChannels(d) +func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) { + return fetchChannels(c) } // FetchAllOpenChannels will return all channels that have the funding // transaction confirmed, and are not waiting for a closing transaction to be // confirmed. -func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) { return fetchChannels( - d, + c, pendingChannelFilter(false), waitingCloseFilter(false), ) @@ -638,8 +658,8 @@ func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { // FetchPendingChannels will return channels that have completed the process of // generating and broadcasting funding transactions, but whose funding // transactions have yet to be confirmed on the blockchain.
-func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, +func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) { + return fetchChannels(c, pendingChannelFilter(true), waitingCloseFilter(false), ) @@ -649,9 +669,9 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { // but are now waiting for a closing transaction to be confirmed. // // NOTE: This includes channels that are also pending to be opened. -func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { return fetchChannels( - d, waitingCloseFilter(true), + c, waitingCloseFilter(true), ) } @@ -692,10 +712,12 @@ func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { // which have a true value returned for *all* of the filters will be returned. // If no filters are provided, every channel in the open channels bucket will // be returned. -func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) { +func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) ( + []*OpenChannel, error) { + var channels []*OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { // Get the bucket dedicated to storing the metadata for open // channels. openChanBucket := tx.ReadBucket(openChannelBucket) @@ -737,7 +759,7 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error "bucket for chain=%x", chainHash[:]) } - nodeChans, err := d.fetchNodeChannels(chainBucket) + nodeChans, err := c.fetchNodeChannels(chainBucket) if err != nil { return fmt.Errorf("unable to read "+ "channel for chain_hash=%x, "+ @@ -786,10 +808,12 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error // it becomes fully closed after a single confirmation. When a channel was // forcibly closed, it will become fully closed after _all_ the pending funds // (if any) have been swept. -func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) { +func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) ( + []*ChannelCloseSummary, error) { + var chanSummaries []*ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrNoClosedChannels @@ -827,9 +851,11 @@ var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary // FetchClosedChannel queries for a channel close summary using the channel // point of the channel in question. -func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { +func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( + *ChannelCloseSummary, error) { + var chanSummary *ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrClosedChannelNotFound @@ -861,11 +887,11 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er // FetchClosedChannelForID queries for a channel close summary using the // channel ID of the channel in question. 
-func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( +func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) ( *ChannelCloseSummary, error) { var chanSummary *ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrClosedChannelNotFound @@ -914,8 +940,12 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( // cooperatively closed and it's reached a single confirmation, or after all // the pending funds in a channel that has been forcibly closed have been // swept. -func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { + var ( + openChannels []*OpenChannel + pruneLinkNode *btcec.PublicKey + ) + err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { var b bytes.Buffer if err := writeOutpoint(&b, chanPoint); err != nil { return err } @@ -961,19 +991,35 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { // other open channels with this peer. If we don't we'll // garbage collect it to ensure we don't establish persistent // connections to peers without open channels. - return d.pruneLinkNode(tx, chanSummary.RemotePub) - }, func() {}) + pruneLinkNode = chanSummary.RemotePub + openChannels, err = c.fetchOpenChannels( + tx, pruneLinkNode, + ) + if err != nil { + return fmt.Errorf("unable to fetch open channels for "+ + "peer %x: %v", + pruneLinkNode.SerializeCompressed(), err) + } + + return nil + }, func() { + openChannels = nil + pruneLinkNode = nil + }) + if err != nil { + return err + } + + // Decide whether we want to remove the link node, based upon the number + // of still open channels. + return c.pruneLinkNode(openChannels, pruneLinkNode) } // pruneLinkNode determines whether we should garbage collect a link node from // the database due to no longer having any open channels with it. If there are // any left, then this acts as a no-op. -func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error { - openChannels, err := d.fetchOpenChannels(tx, remotePub) - if err != nil { - return fmt.Errorf("unable to fetch open channels for peer %x: "+ - "%v", remotePub.SerializeCompressed(), err) - } +func (c *ChannelStateDB) pruneLinkNode(openChannels []*OpenChannel, + remotePub *btcec.PublicKey) error { if len(openChannels) > 0 { return nil @@ -982,27 +1028,42 @@ func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error { log.Infof("Pruning link node %x with zero open channels from database", remotePub.SerializeCompressed()) - return d.deleteLinkNode(tx, remotePub) + return c.linkNodeDB.DeleteLinkNode(remotePub) } // PruneLinkNodes attempts to prune all link nodes found within the database with // whom we no longer have any open channels.
-func (d *DB) PruneLinkNodes() error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { - linkNodes, err := d.fetchAllLinkNodes(tx) +func (c *ChannelStateDB) PruneLinkNodes() error { + allLinkNodes, err := c.linkNodeDB.FetchAllLinkNodes() + if err != nil { + return err + } + + for _, linkNode := range allLinkNodes { + var ( + openChannels []*OpenChannel + linkNode = linkNode + ) + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { + var err error + openChannels, err = c.fetchOpenChannels( + tx, linkNode.IdentityPub, + ) + return err + }, func() { + openChannels = nil + }) if err != nil { return err } - for _, linkNode := range linkNodes { - err := d.pruneLinkNode(tx, linkNode.IdentityPub) - if err != nil { - return err - } + err = c.pruneLinkNode(openChannels, linkNode.IdentityPub) + if err != nil { + return err } + } - return nil - }, func() {}) + return nil } // ChannelShell is a shell of a channel that is meant to be used for channel @@ -1024,8 +1085,8 @@ type ChannelShell struct { // addresses, and finally create an edge within the graph for the channel as // well. This method is idempotent, so repeated calls with the same set of // channel shells won't modify the database after the initial call. -func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { - err := kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) error { + err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { for _, channelShell := range channelShells { channel := channelShell.Chan @@ -1039,7 +1100,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // and link node for this channel. If the channel // already exists, then in order to ensure this method // is idempotent, we'll continue to the next step. - channel.Db = d + channel.Db = c err := syncNewChannel( tx, channel, channelShell.NodeAddrs, ) @@ -1059,41 +1120,28 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // AddrsForNode consults the graph and channel database for all addresses known // to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { - var ( - linkNode *LinkNode - graphNode LightningNode - ) +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, + error) { - dbErr := kvdb.View(d, func(tx kvdb.RTx) error { - var err error + linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub) + if err != nil { + return nil, err + } - linkNode, err = fetchLinkNode(tx, nodePub) - if err != nil { - return err - } - - // We'll also query the graph for this peer to see if they have - // any addresses that we don't currently have stored within the - // link node database. - nodes := tx.ReadBucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - compressedPubKey := nodePub.SerializeCompressed() - graphNode, err = fetchLightningNode(nodes, compressedPubKey) - if err != nil && err != ErrGraphNodeNotFound { - // If the node isn't found, then that's OK, as we still - // have the link node data. - return err - } - - return nil - }, func() { - linkNode = nil - }) - if dbErr != nil { - return nil, dbErr + // We'll also query the graph for this peer to see if they have any + // addresses that we don't currently have stored within the link node + // database. 
+ pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed()) + if err != nil { + return nil, err + } + graphNode, err := d.graph.FetchLightningNode(pubKey) + if err != nil && err != ErrGraphNodeNotFound { + return nil, err + } else if err == ErrGraphNodeNotFound { + // If the node isn't found, then that's OK, as we still have the + // link node data. But any other error needs to be returned. + graphNode = &LightningNode{} } // Now that we have both sources of addrs for this node, we'll use a @@ -1118,16 +1166,18 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { // database. If the channel was already removed (has a closed channel entry), // then we'll return a nil error. Otherwise, we'll insert a new close summary // into the database. -func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { +func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint, + bestHeight uint32) error { + // With the chanPoint constructed, we'll attempt to find the target // channel in the database. If we can't find the channel, then we'll // return the error back to the caller. - dbChan, err := d.FetchChannel(nil, *chanPoint) + dbChan, err := c.FetchChannel(nil, *chanPoint) switch { // If the channel wasn't found, then it's possible that it was already // abandoned from the database. case err == ErrChannelNotFound: - _, closedErr := d.FetchClosedChannel(chanPoint) + _, closedErr := c.FetchClosedChannel(chanPoint) if closedErr != nil { return closedErr } @@ -1163,6 +1213,58 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator) } +// SaveChannelOpeningState saves the serialized channel state for the provided +// chanPoint to the channelOpeningStateBucket. +func (c *ChannelStateDB) SaveChannelOpeningState(outPoint, + serializedState []byte) error { + + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { + bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) + if err != nil { + return err + } + + return bucket.Put(outPoint, serializedState) + }, func() {}) +} + +// GetChannelOpeningState fetches the serialized channel state for the provided +// outPoint from the database, or returns ErrChannelNotFound if the channel +// is not found. +func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { + var serializedState []byte + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { + bucket := tx.ReadBucket(channelOpeningStateBucket) + if bucket == nil { + // If the bucket does not exist, it means we never added + // a channel to the db, so return ErrChannelNotFound. + return ErrChannelNotFound + } + + serializedState = bucket.Get(outPoint) + if serializedState == nil { + return ErrChannelNotFound + } + + return nil + }, func() { + serializedState = nil + }) + return serializedState, err +} + +// DeleteChannelOpeningState removes any state for outPoint from the database. +func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error { + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { + bucket := tx.ReadWriteBucket(channelOpeningStateBucket) + if bucket == nil { + return ErrChannelNotFound + } + + return bucket.Delete(outPoint) + }, func() {}) +} + // syncVersions function is used for safe db version synchronization. It // applies migration functions to the current database and recovers the // previous state of db if at least one error/panic appeared during migration. 
@@ -1236,11 +1338,17 @@ func (d *DB) syncVersions(versions []version) error { }, func() {}) } -// ChannelGraph returns a new instance of the directed channel graph. +// ChannelGraph returns the current instance of the directed channel graph. func (d *DB) ChannelGraph() *ChannelGraph { return d.graph } +// ChannelStateDB returns the sub database that is concerned with the channel +// state. +func (d *DB) ChannelStateDB() *ChannelStateDB { + return d.channelStateDB +} + func getLatestDBVersion(versions []version) uint32 { return versions[len(versions)-1].number } @@ -1290,9 +1398,11 @@ func fetchHistoricalChanBucket(tx kvdb.RTx, // FetchHistoricalChannel fetches open channel data from the historical channel // bucket. -func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, error) { +func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) ( + *OpenChannel, error) { + var channel *OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchHistoricalChanBucket(tx, outPoint) if err != nil { return err @@ -1300,7 +1410,7 @@ func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, erro channel, err = fetchOpenChannel(chanBucket, outPoint) - channel.Db = d + channel.Db = c return err }, func() { channel = nil diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 04c0dcd41..5731c03a8 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -87,15 +87,18 @@ func TestWipe(t *testing.T) { } defer cleanup() - cdb, err := CreateWithBackend(backend) + fullDB, err := CreateWithBackend(backend) if err != nil { t.Fatalf("unable to create channeldb: %v", err) } - defer cdb.Close() + defer fullDB.Close() - if err := cdb.Wipe(); err != nil { + if err := fullDB.Wipe(); err != nil { t.Fatalf("unable to wipe channeldb: %v", err) } + + cdb := fullDB.ChannelStateDB() + // Check correct errors are returned openChannels, err := cdb.FetchAllOpenChannels() require.NoError(t, err, "fetching open channels") @@ -113,12 +116,14 @@ func TestFetchClosedChannelForID(t *testing.T) { const numChans = 101 - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create the test channel state, that we will mutate the index of the // funding point. state := createTestChannelState(t, cdb) @@ -184,18 +189,18 @@ func TestFetchClosedChannelForID(t *testing.T) { func TestAddrsForNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - graph := cdb.ChannelGraph() + graph := fullDB.ChannelGraph() // We'll make a test vertex to insert into the database, as the source // node, but this node will only have half the number of addresses it // usually does. 
- testNode, err := createTestVertex(cdb) + testNode, err := createTestVertex(fullDB) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -210,8 +215,9 @@ func TestAddrsForNode(t *testing.T) { if err != nil { t.Fatalf("unable to recv node pub: %v", err) } - linkNode := cdb.NewLinkNode( - wire.MainNet, nodePub, anotherAddr, + linkNode := NewLinkNode( + fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub, + anotherAddr, ) if err := linkNode.Sync(); err != nil { t.Fatalf("unable to sync link node: %v", err) @@ -219,7 +225,7 @@ func TestAddrsForNode(t *testing.T) { // Now that we've created a link node, as well as a vertex for the // node, we'll query for all its addresses. - nodeAddrs, err := cdb.AddrsForNode(nodePub) + nodeAddrs, err := fullDB.AddrsForNode(nodePub) if err != nil { t.Fatalf("unable to obtain node addrs: %v", err) } @@ -245,12 +251,14 @@ func TestAddrsForNode(t *testing.T) { func TestFetchChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channelState := createTestChannel(t, cdb, openChannelOption()) @@ -349,12 +357,14 @@ func genRandomChannelShell() (*ChannelShell, error) { func TestRestoreChannelShells(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First, we'll make our channel shell, it will only have the minimal // amount of information required for us to initiate the data loss // protection feature. @@ -423,7 +433,9 @@ func TestRestoreChannelShells(t *testing.T) { // We should also be able to find the link node that was inserted by // its public key. - linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) + linkNode, err := fullDB.channelStateDB.linkNodeDB.FetchLinkNode( + channelShell.Chan.IdentityPub, + ) if err != nil { t.Fatalf("unable to fetch link node: %v", err) } @@ -443,12 +455,14 @@ func TestRestoreChannelShells(t *testing.T) { func TestAbandonChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // If we attempt to abandon the state of a channel that doesn't exist // in the open or closed channel bucket, then we should receive an // error. @@ -616,13 +630,15 @@ func TestFetchChannels(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test "+ "database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a pending channel that is not awaiting close. createTestChannel( t, cdb, channelIDOption(pendingChan), @@ -685,12 +701,14 @@ func TestFetchChannels(t *testing.T) { // TestFetchHistoricalChannel tests lookup of historical channels. func TestFetchHistoricalChannel(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a an open channel in the database. 
channel := createTestChannel(t, cdb, openChannelOption()) diff --git a/channeldb/graph.go b/channeldb/graph.go index 678b7ac06..68f4fb537 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -174,39 +174,132 @@ const ( // independently. Edge removal results in the deletion of all edge information // for that edge. type ChannelGraph struct { - db *DB + db kvdb.Backend cacheMu sync.RWMutex rejectCache *rejectCache chanCache *channelCache + graphCache *GraphCache chanScheduler batch.Scheduler nodeScheduler batch.Scheduler } -// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The +// NewChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. -func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int, - batchCommitInterval time.Duration) *ChannelGraph { +func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, + batchCommitInterval time.Duration, + preAllocCacheNumNodes int) (*ChannelGraph, error) { + + if err := initChannelGraph(db); err != nil { + return nil, err + } g := &ChannelGraph{ db: db, rejectCache: newRejectCache(rejectCacheSize), chanCache: newChannelCache(chanCacheSize), + graphCache: NewGraphCache(preAllocCacheNumNodes), } g.chanScheduler = batch.NewTimeScheduler( - db.Backend, &g.cacheMu, batchCommitInterval, + db, &g.cacheMu, batchCommitInterval, ) g.nodeScheduler = batch.NewTimeScheduler( - db.Backend, nil, batchCommitInterval, + db, nil, batchCommitInterval, ) - return g + startTime := time.Now() + log.Debugf("Populating in-memory channel graph, this might take a " + + "while...") + err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error { + return g.graphCache.AddNode(tx, node) + }) + if err != nil { + return nil, err + } + + log.Debugf("Finished populating in-memory channel graph (took %v, %s)", + time.Since(startTime), g.graphCache.Stats()) + + return g, nil } -// Database returns a pointer to the underlying database. -func (c *ChannelGraph) Database() *DB { - return c.db +var graphTopLevelBuckets = [][]byte{ + nodeBucket, + edgeBucket, + edgeIndexBucket, + graphMetaBucket, +} + +// Wipe completely deletes all saved state within all used buckets within the +// database. The deletion is done in a single transaction, therefore this +// operation is fully atomic. +func (c *ChannelGraph) Wipe() error { + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { + for _, tlb := range graphTopLevelBuckets { + err := tx.DeleteTopLevelBucket(tlb) + if err != nil && err != kvdb.ErrBucketNotFound { + return err + } + } + return nil + }, func() {}) + if err != nil { + return err + } + + return initChannelGraph(c.db) +} + +// initChannelGraph creates and initializes a fresh version of the channel +// graph database. In the case that any of the required top-level or nested +// buckets used within the graph database do not yet exist, they are created.
+func initChannelGraph(db kvdb.Backend) error { + err := kvdb.Update(db, func(tx kvdb.RwTx) error { + for _, tlb := range graphTopLevelBuckets { + if _, err := tx.CreateTopLevelBucket(tlb); err != nil { + return err + } + } + + nodes := tx.ReadWriteBucket(nodeBucket) + _, err := nodes.CreateBucketIfNotExists(aliasIndexBucket) + if err != nil { + return err + } + _, err = nodes.CreateBucketIfNotExists(nodeUpdateIndexBucket) + if err != nil { + return err + } + + edges := tx.ReadWriteBucket(edgeBucket) + _, err = edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + _, err = edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + _, err = edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + _, err = edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + + graphMeta := tx.ReadWriteBucket(graphMetaBucket) + _, err = graphMeta.CreateBucketIfNotExists(pruneLogBucket) + return err + }, func() {}) + if err != nil { + return fmt.Errorf("unable to create new channel graph: %v", err) + } + + return nil } // ForEachChannel iterates through all the channel edges stored within the @@ -218,7 +311,9 @@ func (c *ChannelGraph) Database() *DB { // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. -func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { +func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates return kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -270,23 +365,22 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli // ForEachNodeChannel iterates through all channels of a given node, executing the // passed callback with an edge info structure and the policies of each end // of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the +// connecting node, while the second is the incoming edge *from* the // connecting node. If the callback returns an error, then the iteration is // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { +func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex, + cb func(channel *DirectedChannel) error) error { - db := c.db + return c.graphCache.ForEachChannel(node, cb) +} - return nodeTraversal(tx, nodePub, db, cb) +// FetchNodeFeatures returns the features of a given node. +func (c *ChannelGraph) FetchNodeFeatures( + node route.Vertex) (*lnwire.FeatureVector, error) { + + return c.graphCache.GetFeatures(node), nil } // DisabledChannelIDs returns the channel ids of disabled channels. 
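As a rough usage sketch of the new cache-backed API (assuming the helper lives in the channeldb package; the function name, tempDir and peerKey are illustrative placeholders, and the option values simply come from DefaultOptions): NewChannelGraph opens the backend, warms the in-memory GraphCache, and afterwards serves path-finding queries from memory.

func openGraphForPathFinding(tempDir string,
	peerKey route.Vertex) (*ChannelGraph, func(), error) {

	// Any kvdb.Backend will do; here we use a throw-away test backend.
	backend, backendCleanup, err := kvdb.GetTestBackend(tempDir, "graph")
	if err != nil {
		return nil, nil, err
	}

	// NewChannelGraph also populates the in-memory GraphCache from disk
	// before returning, so this can take a moment on a large graph.
	opts := DefaultOptions()
	graph, err := NewChannelGraph(
		backend, opts.RejectCacheSize, opts.ChannelCacheSize,
		opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
	)
	if err != nil {
		backendCleanup()
		return nil, nil, err
	}

	// Channel queries used during path finding are now answered from the
	// cache, without opening a database transaction.
	err = graph.ForEachNodeChannel(peerKey, func(c *DirectedChannel) error {
		log.Debugf("chan=%v other=%v capacity=%v", c.ChannelID,
			c.OtherNode, c.Capacity)
		return nil
	})
	if err != nil {
		backendCleanup()
		return nil, nil, err
	}

	return graph, backendCleanup, nil
}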
@@ -374,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro return kvdb.View(c.db, traversal, func() {}) } +// ForEachNodeCacheable iterates through all the stored vertices/nodes in the +// graph, executing the passed callback with each node encountered. If the +// callback returns an error, then the transaction is aborted and the iteration +// stops early. +func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx, + GraphCacheNode) error) error { + + traversal := func(tx kvdb.RTx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.ReadBucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + cacheableNode := newGraphCacheNode(route.Vertex{}, nil) + return nodes.ForEach(func(pubKey, nodeBytes []byte) error { + // If this is the source key, then we skip this + // iteration as the value for this key is a pubKey + // rather than raw node information. + if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { + return nil + } + + nodeReader := bytes.NewReader(nodeBytes) + err := deserializeLightningNodeCacheable( + nodeReader, cacheableNode, + ) + if err != nil { + return err + } + + // Execute the callback, the transaction will abort if + // this returns an error. + return cb(tx, cacheableNode) + }) + } + + return kvdb.View(c.db, traversal, func() {}) +} + // SourceNode returns the source node of the graph. The source node is treated // as the center node within a star-graph. This method may be used to kick off // a path finding algorithm in order to explore the reachability of another @@ -465,6 +600,13 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode, r := &batch.Request{ Update: func(tx kvdb.RwTx) error { + cNode := newGraphCacheNode( + node.PubKeyBytes, node.Features, + ) + if err := c.graphCache.AddNode(tx, cNode); err != nil { + return err + } + return addLightningNode(tx, node) }, } @@ -543,6 +685,8 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error { return ErrGraphNodeNotFound } + c.graphCache.RemoveNode(nodePub) + return c.deleteLightningNode(nodes, nodePub[:]) }, func() {}) } @@ -669,6 +813,8 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error return ErrEdgeAlreadyExist } + c.graphCache.AddChannel(edge, nil, nil) + // Before we insert the channel into the database, we'll ensure that // both nodes already exist in the channel graph. If either node // doesn't, then we'll insert a "shell" node that just includes its @@ -868,6 +1014,8 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { return ErrEdgeNotFound } + c.graphCache.UpdateChannel(edge) + return putChanEdgeInfo(edgeIndex, edge, chanKey) }, func() {}) } @@ -953,7 +1101,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // will be returned if that outpoint isn't known to be // a channel. If no error is returned, then a channel // was successfully pruned. 
- err = delChannelEdge( + err = c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, chanID, false, false, ) @@ -1004,6 +1152,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, c.chanCache.remove(channel.ChannelID) } + log.Debugf("Pruned graph, cache now has %s", c.graphCache.Stats()) + return chansClosed, nil } @@ -1104,6 +1254,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, continue } + c.graphCache.RemoveNode(nodePubKey) + // If we reach this point, then there are no longer any edges // that connect this node, so we can delete it. if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { @@ -1202,7 +1354,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf } for _, k := range keys { - err = delChannelEdge( + err = c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, k, false, false, ) @@ -1310,7 +1462,9 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { // true, then when we mark these edges as zombies, we'll set up the keys such // that we require the node that failed to send the fresh update to be the one // that resurrects the channel from its zombie state. -func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...uint64) error { +func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, + chanIDs ...uint64) error { + // TODO(roasbeef): possibly delete from node bucket if node has no more // channels // TODO(roasbeef): don't delete both edges? @@ -1343,7 +1497,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...u var rawChanID [8]byte for _, chanID := range chanIDs { byteOrder.PutUint64(rawChanID[:], chanID) - err := delChannelEdge( + err := c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, rawChanID[:], true, strictZombiePruning, ) @@ -1472,7 +1626,9 @@ type ChannelEdge struct { // ChanUpdatesInHorizon returns all the known channel edges which have at least // one edge that has an update timestamp within the specified horizon. -func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, + endTime time.Time) ([]ChannelEdge, error) { + // To ensure we don't return duplicate ChannelEdges, we'll use an // additional map to keep track of the edges already seen to prevent // re-adding it. @@ -1605,7 +1761,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha // update timestamp within the passed range. This method can be used by two // nodes to quickly determine if they have the same set of up to date node // announcements. 
-func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, + endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -1933,7 +2091,7 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, return nil } -func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, +func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) @@ -1941,6 +2099,11 @@ func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, return err } + c.graphCache.RemoveChannel( + edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, + edgeInfo.ChannelID, + ) + // We'll also remove the entry in the edge update index bucket before // we delete the edges themselves so we can access their last update // times. @@ -2075,7 +2238,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy, }, Update: func(tx kvdb.RwTx) error { var err error - isUpdate1, err = updateEdgePolicy(tx, edge) + isUpdate1, err = updateEdgePolicy( + tx, edge, c.graphCache, + ) // Silence ErrEdgeNotFound so that the batch can // succeed, but propagate the error via local state. @@ -2138,7 +2303,9 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) { // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { +func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, + graphCache *GraphCache) (bool, error) { + edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return false, ErrEdgeNotFound @@ -2186,6 +2353,14 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { return false, err } + var ( + fromNodePubKey route.Vertex + toNodePubKey route.Vertex + ) + copy(fromNodePubKey[:], fromNode) + copy(toNodePubKey[:], toNode) + graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) + return isUpdate1, nil } @@ -2232,7 +2407,7 @@ type LightningNode struct { // compatible manner. ExtraOpaqueData []byte - db *DB + db kvdb.Backend // TODO(roasbeef): discovery will need storage to keep it's last IP // address and re-announce if interface changes? @@ -2356,17 +2531,11 @@ func (l *LightningNode) isPublic(tx kvdb.RTx, sourcePubKey []byte) (bool, error) // FetchLightningNode attempts to look up a target node by its identity public // key. If the node isn't found in the database, then ErrGraphNodeNotFound is // returned. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( +func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( *LightningNode, error) { var node *LightningNode - - fetchNode := func(tx kvdb.RTx) error { + err := kvdb.View(c.db, func(tx kvdb.RTx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. 
nodes := tx.ReadBucket(nodeBucket) @@ -2393,14 +2562,9 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( node = &n return nil - } - - var err error - if tx == nil { - err = kvdb.View(c.db, fetchNode, func() {}) - } else { - err = fetchNode(tx) - } + }, func() { + node = nil + }) if err != nil { return nil, err } @@ -2408,6 +2572,52 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( return node, nil } +// graphCacheNode is a struct that wraps a LightningNode in a way that it can be +// cached in the graph cache. +type graphCacheNode struct { + pubKeyBytes route.Vertex + features *lnwire.FeatureVector + + nodeScratch [8]byte +} + +// newGraphCacheNode returns a new cache optimized node. +func newGraphCacheNode(pubKey route.Vertex, + features *lnwire.FeatureVector) *graphCacheNode { + + return &graphCacheNode{ + pubKeyBytes: pubKey, + features: features, + } +} + +// PubKey returns the node's public identity key. +func (n *graphCacheNode) PubKey() route.Vertex { + return n.pubKeyBytes +} + +// Features returns the node's features. +func (n *graphCacheNode) Features() *lnwire.FeatureVector { + return n.features +} + +// ForEachChannel iterates through all channels of this node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) +} + +var _ GraphCacheNode = (*graphCacheNode)(nil) + // HasLightningNode determines if the graph has a vertex identified by the // target node identity public key. If the node exists in the database, a // timestamp of when the data for the node was lasted updated is returned along @@ -2460,7 +2670,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. -func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, +func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { @@ -2548,7 +2758,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, // ForEachChannel iterates through all channels of this node, executing the // passed callback with an edge info structure and the policies of each end // of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the +// connecting node, while the second is the incoming edge *from* the // connecting node. If the callback returns an error, then the iteration is // halted with the error propagated back up to the caller. // @@ -2559,7 +2769,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, // be nil and a fresh transaction will be created to execute the graph // traversal. 
func (l *LightningNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] db := l.db @@ -2627,7 +2838,7 @@ type ChannelEdgeInfo struct { // compatible manner. ExtraOpaqueData []byte - db *DB + db kvdb.Backend } // AddNodeKeys is a setter-like method that can be used to replace the set of @@ -2988,7 +3199,7 @@ type ChannelEdgePolicy struct { // compatible manner. ExtraOpaqueData []byte - db *DB + db kvdb.Backend } // Signature is a channel announcement signature, which is needed for proper @@ -3406,7 +3617,7 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64, c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := kvdb.Batch(c.db.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(c.db, func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -3417,6 +3628,8 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64, "bucket: %w", err) } + c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID) + return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2) }) if err != nil { @@ -3471,6 +3684,18 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { c.rejectCache.remove(chanID) c.chanCache.remove(chanID) + // We need to add the channel back into our graph cache, otherwise we + // won't use it for path finding. + edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) + if err != nil { + return err + } + for _, edgeInfo := range edgeInfos { + c.graphCache.AddChannel( + edgeInfo.Info, edgeInfo.Policy1, edgeInfo.Policy2, + ) + } + return nil } @@ -3696,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket, return deserializeLightningNode(nodeReader) } +func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error { + // Always populate a feature vector, even if we don't have a node + // announcement and short circuit below. + node.features = lnwire.EmptyFeatureVector() + + // Skip ahead: + // - LastUpdate (8 bytes) + if _, err := r.Read(node.nodeScratch[:]); err != nil { + return err + } + + if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil { + return err + } + + // Read the node announcement flag. + if _, err := r.Read(node.nodeScratch[:2]); err != nil { + return err + } + hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2]) + + // The rest of the data is optional, and will only be there if we got a + // node announcement for this node. + if hasNodeAnn == 0 { + return nil + } + + // We did get a node announcement for this node, so we'll have the rest + // of the data available. 
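+	// The three single-byte reads below skip the announcement's RGB color
+	// (R, G, B) and the ReadVarString call skips the alias; neither is
+	// needed for path finding. Everything after the feature vector
+	// (addresses, extra opaque data) is simply left unread.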
+ var rgb uint8 + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + + if _, err := wire.ReadVarString(r, 0); err != nil { + return err + } + + return node.features.Decode(r) +} + func deserializeLightningNode(r io.Reader) (LightningNode, error) { var ( node LightningNode @@ -4102,7 +4374,7 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, nodes kvdb.RBucket, chanID []byte, - db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { + db kvdb.Backend) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { edgeInfo := edgeIndex.Get(chanID) if edgeInfo == nil { diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go new file mode 100644 index 000000000..bec44c3e5 --- /dev/null +++ b/channeldb/graph_cache.go @@ -0,0 +1,460 @@ +package channeldb + +import ( + "fmt" + "sync" + + "github.com/btcsuite/btcutil" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// GraphCacheNode is an interface for all the information the cache needs to know +// about a lightning node. +type GraphCacheNode interface { + // PubKey is the node's public identity key. + PubKey() route.Vertex + + // Features returns the node's p2p features. + Features() *lnwire.FeatureVector + + // ForEachChannel iterates through all channels of a given node, + // executing the passed callback with an edge info structure and the + // policies of each end of the channel. The first edge policy is the + // outgoing edge *to* the connecting node, while the second is the + // incoming edge *from* the connecting node. If the callback returns an + // error, then the iteration is halted with the error propagated back up + // to the caller. + ForEachChannel(kvdb.RTx, + func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error +} + +// CachedEdgePolicy is a struct that only caches the information of a +// ChannelEdgePolicy that we actually use for pathfinding and therefore need to +// store in the cache. +type CachedEdgePolicy struct { + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // MessageFlags is a bitfield which indicates the presence of optional + // fields (like max_htlc) in the policy. + MessageFlags lnwire.ChanUpdateMsgFlags + + // ChannelFlags is a bitfield which signals the capabilities of the + // channel as well as the directed edge this update applies to. + ChannelFlags lnwire.ChanUpdateChanFlags + + // TimeLockDelta is the number of blocks this node will subtract from + // the expiry of an incoming HTLC. This value expresses the time buffer + // the node would like to HTLC exchanges. + TimeLockDelta uint16 + + // MinHTLC is the smallest value HTLC this node will forward, expressed + // in millisatoshi. + MinHTLC lnwire.MilliSatoshi + + // MaxHTLC is the largest value HTLC this node will forward, expressed + // in millisatoshi. + MaxHTLC lnwire.MilliSatoshi + + // FeeBaseMSat is the base HTLC fee that will be charged for forwarding + // ANY HTLC, expressed in mSAT's. 
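+	// For example, with FeeBaseMSat=1000 and
+	// FeeProportionalMillionths=100, forwarding a 1,000,000 msat HTLC
+	// costs 1000 + 1,000,000*100/1,000,000 = 1,100 msat (see ComputeFee
+	// below).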
+ FeeBaseMSat lnwire.MilliSatoshi + + // FeeProportionalMillionths is the rate that the node will charge for + // HTLCs for each millionth of a satoshi forwarded. + FeeProportionalMillionths lnwire.MilliSatoshi + + // ToNodePubKey is a function that returns the to node of a policy. + // Since we only ever store the inbound policy, this is always the node + // that we query the channels for in ForEachChannel(). Therefore, we can + // save a lot of space by not storing this information in the memory and + // instead just set this function when we copy the policy from cache in + // ForEachChannel(). + ToNodePubKey func() route.Vertex + + // ToNodeFeatures are the to node's features. They are never set while + // the edge is in the cache, only on the copy that is returned in + // ForEachChannel(). + ToNodeFeatures *lnwire.FeatureVector +} + +// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over +// the passed active payment channel. This value is currently computed as +// specified in BOLT07, but will likely change in the near future. +func (c *CachedEdgePolicy) ComputeFee( + amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts +} + +// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming +// amount. +func (c *CachedEdgePolicy) ComputeFeeFromIncoming( + incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return incomingAmt - divideCeil( + feeRateParts*(incomingAmt-c.FeeBaseMSat), + feeRateParts+c.FeeProportionalMillionths, + ) +} + +// NewCachedPolicy turns a full policy into a minimal one that can be cached. +func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy { + return &CachedEdgePolicy{ + ChannelID: policy.ChannelID, + MessageFlags: policy.MessageFlags, + ChannelFlags: policy.ChannelFlags, + TimeLockDelta: policy.TimeLockDelta, + MinHTLC: policy.MinHTLC, + MaxHTLC: policy.MaxHTLC, + FeeBaseMSat: policy.FeeBaseMSat, + FeeProportionalMillionths: policy.FeeProportionalMillionths, + } +} + +// DirectedChannel is a type that stores the channel information as seen from +// one side of the channel. +type DirectedChannel struct { + // ChannelID is the unique identifier of this channel. + ChannelID uint64 + + // IsNode1 indicates if this is the node with the smaller public key. + IsNode1 bool + + // OtherNode is the public key of the node on the other end of this + // channel. + OtherNode route.Vertex + + // Capacity is the announced capacity of this channel in satoshis. + Capacity btcutil.Amount + + // OutPolicySet is a boolean that indicates whether the node has an + // outgoing policy set. For pathfinding only the existence of the policy + // is important to know, not the actual content. + OutPolicySet bool + + // InPolicy is the incoming policy *from* the other node to this node. + // In path finding, we're walking backward from the destination to the + // source, so we're always interested in the edge that arrives to us + // from the other node. + InPolicy *CachedEdgePolicy +} + +// DeepCopy creates a deep copy of the channel, including the incoming policy. +func (c *DirectedChannel) DeepCopy() *DirectedChannel { + channelCopy := *c + + if channelCopy.InPolicy != nil { + inPolicyCopy := *channelCopy.InPolicy + channelCopy.InPolicy = &inPolicyCopy + + // The fields for the ToNode can be overwritten by the path + // finding algorithm, which is why we need a deep copy in the + // first place. 
So we always start out with nil values, just to + // be sure they don't contain any old data. + channelCopy.InPolicy.ToNodePubKey = nil + channelCopy.InPolicy.ToNodeFeatures = nil + } + + return &channelCopy +} + +// GraphCache is a type that holds a minimal set of information of the public +// channel graph that can be used for pathfinding. +type GraphCache struct { + nodeChannels map[route.Vertex]map[uint64]*DirectedChannel + nodeFeatures map[route.Vertex]*lnwire.FeatureVector + + mtx sync.RWMutex +} + +// NewGraphCache creates a new graphCache. +func NewGraphCache(preAllocNumNodes int) *GraphCache { + return &GraphCache{ + nodeChannels: make( + map[route.Vertex]map[uint64]*DirectedChannel, + // A channel connects two nodes, so we can look it up + // from both sides, meaning we get double the number of + // entries. + preAllocNumNodes*2, + ), + nodeFeatures: make( + map[route.Vertex]*lnwire.FeatureVector, + preAllocNumNodes, + ), + } +} + +// Stats returns statistics about the current cache size. +func (c *GraphCache) Stats() string { + c.mtx.RLock() + defer c.mtx.RUnlock() + + numChannels := 0 + for node := range c.nodeChannels { + numChannels += len(c.nodeChannels[node]) + } + return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+ + "num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels), + numChannels) +} + +// AddNode adds a graph node, including all the (directed) channels of that +// node. +func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { + nodePubKey := node.PubKey() + + // Only hold the lock for a short time. The `ForEachChannel()` below is + // possibly slow as it has to go to the backend, so we can unlock + // between the calls. And the AddChannel() method will acquire its own + // lock anyway. + c.mtx.Lock() + c.nodeFeatures[nodePubKey] = node.Features() + c.mtx.Unlock() + + return node.ForEachChannel( + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + outPolicy *ChannelEdgePolicy, + inPolicy *ChannelEdgePolicy) error { + + c.AddChannel(info, outPolicy, inPolicy) + + return nil + }, + ) +} + +// AddChannel adds a non-directed channel, meaning that the order of policy 1 +// and policy 2 does not matter, the directionality is extracted from the info +// and policy flags automatically. The policy will be set as the outgoing policy +// on one node and the incoming policy on the peer's side. +func (c *GraphCache) AddChannel(info *ChannelEdgeInfo, + policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) { + + if info == nil { + return + } + + if policy1 != nil && policy1.IsDisabled() && + policy2 != nil && policy2.IsDisabled() { + + return + } + + // Create the edge entry for both nodes. + c.mtx.Lock() + c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{ + ChannelID: info.ChannelID, + IsNode1: true, + OtherNode: info.NodeKey2Bytes, + Capacity: info.Capacity, + }) + c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{ + ChannelID: info.ChannelID, + IsNode1: false, + OtherNode: info.NodeKey1Bytes, + Capacity: info.Capacity, + }) + c.mtx.Unlock() + + // The policy's node is always the to_node. So if policy 1 has to_node + // of node 2 then we have the policy 1 as seen from node 1. 
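+	// Concretely: if policy1's to_node is node 2, then policy1 was set by
+	// node 1 and is node 1's outgoing (node 2's incoming) policy, so the
+	// default fromNode=node1/toNode=node2 pair below is kept. Otherwise
+	// the pair is swapped before handing the policy to UpdatePolicy.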
+ if policy1 != nil { + fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes + if policy1.Node.PubKeyBytes != info.NodeKey2Bytes { + fromNode, toNode = toNode, fromNode + } + isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0 + c.UpdatePolicy(policy1, fromNode, toNode, isEdge1) + } + if policy2 != nil { + fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes + if policy2.Node.PubKeyBytes != info.NodeKey1Bytes { + fromNode, toNode = toNode, fromNode + } + isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0 + c.UpdatePolicy(policy2, fromNode, toNode, isEdge1) + } +} + +// updateOrAddEdge makes sure the edge information for a node is either updated +// if it already exists or is added to that node's list of channels. +func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { + if len(c.nodeChannels[node]) == 0 { + c.nodeChannels[node] = make(map[uint64]*DirectedChannel) + } + + c.nodeChannels[node][edge.ChannelID] = edge +} + +// UpdatePolicy updates a single policy on both the from and to node. The order +// of the from and to node is not strictly important. But we assume that a +// channel edge was added beforehand so that the directed channel struct already +// exists in the cache. +func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, + toNode route.Vertex, edge1 bool) { + + c.mtx.Lock() + defer c.mtx.Unlock() + + updatePolicy := func(nodeKey route.Vertex) { + if len(c.nodeChannels[nodeKey]) == 0 { + return + } + + channel, ok := c.nodeChannels[nodeKey][policy.ChannelID] + if !ok { + return + } + + // Edge 1 is defined as the policy for the direction of node1 to + // node2. + switch { + // This is node 1, and it is edge 1, so this is the outgoing + // policy for node 1. + case channel.IsNode1 && edge1: + channel.OutPolicySet = true + + // This is node 2, and it is edge 2, so this is the outgoing + // policy for node 2. + case !channel.IsNode1 && !edge1: + channel.OutPolicySet = true + + // The other two cases left mean it's the inbound policy for the + // node. + default: + channel.InPolicy = NewCachedPolicy(policy) + } + } + + updatePolicy(fromNode) + updatePolicy(toNode) +} + +// RemoveNode completely removes a node and all its channels (including the +// peer's side). +func (c *GraphCache) RemoveNode(node route.Vertex) { + c.mtx.Lock() + defer c.mtx.Unlock() + + delete(c.nodeFeatures, node) + + // First remove all channels from the other nodes' lists. + for _, channel := range c.nodeChannels[node] { + c.removeChannelIfFound(channel.OtherNode, channel.ChannelID) + } + + // Then remove our whole node completely. + delete(c.nodeChannels, node) +} + +// RemoveChannel removes a single channel between two nodes. +func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) { + c.mtx.Lock() + defer c.mtx.Unlock() + + // Remove that one channel from both sides. + c.removeChannelIfFound(node1, chanID) + c.removeChannelIfFound(node2, chanID) +} + +// removeChannelIfFound removes a single channel from one side. +func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) { + if len(c.nodeChannels[node]) == 0 { + return + } + + delete(c.nodeChannels[node], chanID) +} + +// UpdateChannel updates the channel edge information for a specific edge. We +// expect the edge to already exist and be known. If it does not yet exist, this +// call is a no-op. 
+func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) { + c.mtx.Lock() + defer c.mtx.Unlock() + + if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 || + len(c.nodeChannels[info.NodeKey2Bytes]) == 0 { + + return + } + + channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID] + if ok { + // We only expect to be called when the channel is already + // known. + channel.Capacity = info.Capacity + channel.OtherNode = info.NodeKey2Bytes + } + + channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID] + if ok { + channel.Capacity = info.Capacity + channel.OtherNode = info.NodeKey1Bytes + } +} + +// ForEachChannel invokes the given callback for each channel of the given node. +func (c *GraphCache) ForEachChannel(node route.Vertex, + cb func(channel *DirectedChannel) error) error { + + c.mtx.RLock() + defer c.mtx.RUnlock() + + channels, ok := c.nodeChannels[node] + if !ok { + return nil + } + + features, ok := c.nodeFeatures[node] + if !ok { + log.Warnf("Node %v has no features defined, falling back to "+ + "default feature vector for path finding", node) + + features = lnwire.EmptyFeatureVector() + } + + toNodeCallback := func() route.Vertex { + return node + } + + for _, channel := range channels { + // We need to copy the channel and policy to avoid it being + // updated in the cache if the path finding algorithm sets + // fields on it (currently only the ToNodeFeatures of the + // policy). + channelCopy := channel.DeepCopy() + if channelCopy.InPolicy != nil { + channelCopy.InPolicy.ToNodePubKey = toNodeCallback + channelCopy.InPolicy.ToNodeFeatures = features + } + + if err := cb(channelCopy); err != nil { + return err + } + } + + return nil +} + +// GetFeatures returns the features of the node with the given ID. +func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector { + c.mtx.RLock() + defer c.mtx.RUnlock() + + features, ok := c.nodeFeatures[node] + if !ok || features == nil { + // The router expects the features to never be nil, so we return + // an empty feature set instead. 
+ return lnwire.EmptyFeatureVector() + } + + return features +} diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go new file mode 100644 index 000000000..09cfbf237 --- /dev/null +++ b/channeldb/graph_cache_test.go @@ -0,0 +1,147 @@ +package channeldb + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" +) + +var ( + pubKey1Bytes, _ = hex.DecodeString( + "0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" + + "22c91d", + ) + pubKey2Bytes, _ = hex.DecodeString( + "038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" + + "f4484f", + ) + + pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes) + pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes) +) + +type node struct { + pubKey route.Vertex + features *lnwire.FeatureVector + + edgeInfos []*ChannelEdgeInfo + outPolicies []*ChannelEdgePolicy + inPolicies []*ChannelEdgePolicy +} + +func (n *node) PubKey() route.Vertex { + return n.pubKey +} +func (n *node) Features() *lnwire.FeatureVector { + return n.features +} + +func (n *node) ForEachChannel(tx kvdb.RTx, + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + for idx := range n.edgeInfos { + err := cb( + tx, n.edgeInfos[idx], n.outPolicies[idx], + n.inPolicies[idx], + ) + if err != nil { + return err + } + } + + return nil +} + +// TestGraphCacheAddNode tests that a channel going from node A to node B can be +// cached correctly, independent of the direction we add the channel as. +func TestGraphCacheAddNode(t *testing.T) { + runTest := func(nodeA, nodeB route.Vertex) { + t.Helper() + + channelFlagA, channelFlagB := 0, 1 + if nodeA == pubKey2 { + channelFlagA, channelFlagB = 1, 0 + } + + outPolicy1 := &ChannelEdgePolicy{ + ChannelID: 1000, + ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), + Node: &LightningNode{ + PubKeyBytes: nodeB, + Features: lnwire.EmptyFeatureVector(), + }, + } + inPolicy1 := &ChannelEdgePolicy{ + ChannelID: 1000, + ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), + Node: &LightningNode{ + PubKeyBytes: nodeA, + Features: lnwire.EmptyFeatureVector(), + }, + } + node := &node{ + pubKey: nodeA, + features: lnwire.EmptyFeatureVector(), + edgeInfos: []*ChannelEdgeInfo{{ + ChannelID: 1000, + // Those are direction independent! 
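+				// The edge info always lists the keys in the
+				// same order (pubKey1 then pubKey2), no matter
+				// whether the channel is added from nodeA's or
+				// nodeB's point of view.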
+ NodeKey1Bytes: pubKey1, + NodeKey2Bytes: pubKey2, + Capacity: 500, + }}, + outPolicies: []*ChannelEdgePolicy{outPolicy1}, + inPolicies: []*ChannelEdgePolicy{inPolicy1}, + } + cache := NewGraphCache(10) + require.NoError(t, cache.AddNode(nil, node)) + + var fromChannels, toChannels []*DirectedChannel + _ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error { + fromChannels = append(fromChannels, c) + return nil + }) + _ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error { + toChannels = append(toChannels, c) + return nil + }) + + require.Len(t, fromChannels, 1) + require.Len(t, toChannels, 1) + + require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet) + assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy) + + require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet) + assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) + } + + runTest(pubKey1, pubKey2) + runTest(pubKey2, pubKey1) +} + +func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy, + cached *CachedEdgePolicy) { + + require.Equal(t, original.ChannelID, cached.ChannelID) + require.Equal(t, original.MessageFlags, cached.MessageFlags) + require.Equal(t, original.ChannelFlags, cached.ChannelFlags) + require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta) + require.Equal(t, original.MinHTLC, cached.MinHTLC) + require.Equal(t, original.MaxHTLC, cached.MaxHTLC) + require.Equal(t, original.FeeBaseMSat, cached.FeeBaseMSat) + require.Equal( + t, original.FeeProportionalMillionths, + cached.FeeProportionalMillionths, + ) + require.Equal( + t, + route.Vertex(original.Node.PubKeyBytes), + cached.ToNodePubKey(), + ) + require.Equal(t, original.Node.Features, cached.ToNodeFeatures) +} diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index ccd5f379e..27b979842 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -6,10 +6,12 @@ import ( "errors" "fmt" "image/color" + "io/ioutil" "math" "math/big" prand "math/rand" "net" + "os" "reflect" "runtime" "sync" @@ -19,6 +21,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -40,12 +43,57 @@ var ( _, _ = testSig.R.SetString("63724406601629180062774974542967536251589935445068131219452686511677818569431", 10) _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) - testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) + testFeatures = lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), + lnwire.Features, + ) testPub = route.Vertex{2, 202, 4} ) -func createLightningNode(db *DB, priv *btcec.PrivateKey) (*LightningNode, error) { +// MakeTestGraph creates a new instance of the ChannelGraph for testing +// purposes. A callback which cleans up the created temporary directories is +// also returned and intended to be executed after the test completes. +func MakeTestGraph(modifiers ...OptionModifier) (*ChannelGraph, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channelgraph") + if err != nil { + return nil, nil, err + } + + opts := DefaultOptions() + for _, modifier := range modifiers { + modifier(&opts) + } + + // Next, create channelgraph for the first time. 
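+	// GetTestBackend hands back a throw-away kvdb backend; which concrete
+	// database that is depends on the kvdb build tags in effect.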
+ backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr") + if err != nil { + backendCleanup() + return nil, nil, err + } + + graph, err := NewChannelGraph( + backend, opts.RejectCacheSize, opts.ChannelCacheSize, + opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, + ) + if err != nil { + backendCleanup() + _ = os.RemoveAll(tempDirName) + return nil, nil, err + } + + cleanUp := func() { + _ = backend.Close() + backendCleanup() + _ = os.RemoveAll(tempDirName) + } + + return graph, cleanUp, nil +} + +func createLightningNode(db kvdb.Backend, priv *btcec.PrivateKey) (*LightningNode, error) { updateTime := prand.Int63() pub := priv.PubKey().SerializeCompressed() @@ -64,7 +112,7 @@ func createLightningNode(db *DB, priv *btcec.PrivateKey) (*LightningNode, error) return n, nil } -func createTestVertex(db *DB) (*LightningNode, error) { +func createTestVertex(db kvdb.Backend) (*LightningNode, error) { priv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { return nil, err @@ -76,14 +124,12 @@ func createTestVertex(db *DB) (*LightningNode, error) { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test basic insertion/deletion for vertexes from the // graph, so we'll create a test vertex to start with. node := &LightningNode{ @@ -96,7 +142,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { Addresses: testAddrs, ExtraOpaqueData: []byte("extra new data"), PubKeyBytes: testPub, - db: db, + db: graph.db, } // First, insert the node into the graph DB. This should succeed @@ -104,10 +150,11 @@ func TestNodeInsertionAndDeletion(t *testing.T) { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node, testFeatures) // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(nil, testPub) + dbNode, err := graph.FetchLightningNode(testPub) if err != nil { t.Fatalf("unable to locate node: %v", err) } @@ -128,10 +175,11 @@ func TestNodeInsertionAndDeletion(t *testing.T) { if err := graph.DeleteLightningNode(testPub); err != nil { t.Fatalf("unable to delete node; %v", err) } + assertNodeNotInCache(t, graph, testPub) // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(nil, testPub) + _, err = graph.FetchLightningNode(testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -142,14 +190,12 @@ func TestNodeInsertionAndDeletion(t *testing.T) { func TestPartialNode(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We want to be able to insert nodes into the graph that only has the // PubKey set. node := &LightningNode{ @@ -160,10 +206,11 @@ func TestPartialNode(t *testing.T) { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node, nil) // Next, fetch the node from the database to ensure everything was // serialized properly. 
- dbNode, err := graph.FetchLightningNode(nil, testPub) + dbNode, err := graph.FetchLightningNode(testPub) if err != nil { t.Fatalf("unable to locate node: %v", err) } @@ -180,7 +227,7 @@ func TestPartialNode(t *testing.T) { HaveNodeAnnouncement: false, LastUpdate: time.Unix(0, 0), PubKeyBytes: testPub, - db: db, + db: graph.db, } if err := compareNodes(node, dbNode); err != nil { @@ -192,10 +239,11 @@ func TestPartialNode(t *testing.T) { if err := graph.DeleteLightningNode(testPub); err != nil { t.Fatalf("unable to delete node: %v", err) } + assertNodeNotInCache(t, graph, testPub) // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(nil, testPub) + _, err = graph.FetchLightningNode(testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -204,17 +252,15 @@ func TestPartialNode(t *testing.T) { func TestAliasLookup(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the alias index within the database, so first // create a new test node. - testNode, err := createTestVertex(db) + testNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -241,7 +287,7 @@ func TestAliasLookup(t *testing.T) { } // Ensure that looking up a non-existent alias results in an error. - node, err := createTestVertex(db) + node, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -258,17 +304,15 @@ func TestAliasLookup(t *testing.T) { func TestSourceNode(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() - defer cleanUp() + graph, cleanUp, err := MakeTestGraph() if err != nil { t.Fatalf("unable to make test database: %v", err) } - - graph := db.ChannelGraph() + defer cleanUp() // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. - testNode, err := createTestVertex(db) + testNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -299,21 +343,19 @@ func TestSourceNode(t *testing.T) { func TestEdgeInsertionDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -356,6 +398,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { if err := graph.AddChannelEdge(&edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) // Ensure that both policies are returned as unknown (nil). 
_, e1, e2, err := graph.FetchChannelEdgesByID(chanID) @@ -371,6 +414,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { if err := graph.DeleteChannelEdges(false, chanID); err != nil { t.Fatalf("unable to delete edge: %v", err) } + assertNoEdge(t, graph, chanID) // Ensure that any query attempts to lookup the delete channel edge are // properly deleted. @@ -434,14 +478,13 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, func TestDisconnectBlockAtHeight(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -451,11 +494,11 @@ func TestDisconnectBlockAtHeight(t *testing.T) { // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -511,6 +554,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if err := graph.AddChannelEdge(&edgeInfo3); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo2) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) // Call DisconnectBlockAtHeight, which should prune every channel // that has a funding height of 'height' or greater. @@ -518,6 +564,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if err != nil { t.Fatalf("unable to prune %v", err) } + assertNoEdge(t, graph, edgeInfo.ChannelID) + assertNoEdge(t, graph, edgeInfo2.ChannelID) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) // The two edges should have been removed. if len(removed) != 2 { @@ -641,7 +690,7 @@ func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, } } -func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, +func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) { var ( @@ -721,44 +770,46 @@ func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, func TestEdgeInfoUpdates(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + assertNodeInCache(t, graph, node1, testFeatures) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node2, testFeatures) // Create an edge and add it to the db. 
- edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { t.Fatalf("expected ErrEdgeNotFound, got: %v", err) } + require.Len(t, graph.graphCache.nodeChannels, 0) // Add the edge info. if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo) chanID := edgeInfo.ChannelID outpoint := edgeInfo.ChannelPoint @@ -768,9 +819,11 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.UpdateEdgePolicy(edge1); err != nil { t.Fatalf("unable to update edge: %v", err) } + assertEdgeWithPolicyInCache(t, graph, edgeInfo, edge1, true) if err := graph.UpdateEdgePolicy(edge2); err != nil { t.Fatalf("unable to update edge: %v", err) } + assertEdgeWithPolicyInCache(t, graph, edgeInfo, edge2, false) // Check for existence of the edge within the database, it should be // found. @@ -825,13 +878,198 @@ func TestEdgeInfoUpdates(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } -func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { - update := prand.Int63() +func assertNodeInCache(t *testing.T, g *ChannelGraph, n *LightningNode, + expectedFeatures *lnwire.FeatureVector) { - return newEdgePolicy(chanID, op, db, update) + // Let's check the internal view first. + require.Equal( + t, expectedFeatures, g.graphCache.nodeFeatures[n.PubKeyBytes], + ) + + // The external view should reflect this as well. Except when we expect + // the features to be nil internally, we return an empty feature vector + // on the public interface instead. + if expectedFeatures == nil { + expectedFeatures = lnwire.EmptyFeatureVector() + } + features := g.graphCache.GetFeatures(n.PubKeyBytes) + require.Equal(t, expectedFeatures, features) } -func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, +func assertNodeNotInCache(t *testing.T, g *ChannelGraph, n route.Vertex) { + _, ok := g.graphCache.nodeFeatures[n] + require.False(t, ok) + + _, ok = g.graphCache.nodeChannels[n] + require.False(t, ok) + + // We should get the default features for this node. + features := g.graphCache.GetFeatures(n) + require.Equal(t, lnwire.EmptyFeatureVector(), features) +} + +func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, + e *ChannelEdgeInfo) { + + // Let's check the internal view first. + require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey1Bytes]) + require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey2Bytes]) + + expectedNode1Channel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: true, + OtherNode: e.NodeKey2Bytes, + Capacity: e.Capacity, + OutPolicySet: false, + InPolicy: nil, + } + require.Contains( + t, g.graphCache.nodeChannels[e.NodeKey1Bytes], e.ChannelID, + ) + require.Equal( + t, expectedNode1Channel, + g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID], + ) + + expectedNode2Channel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: false, + OtherNode: e.NodeKey1Bytes, + Capacity: e.Capacity, + OutPolicySet: false, + InPolicy: nil, + } + require.Contains( + t, g.graphCache.nodeChannels[e.NodeKey2Bytes], e.ChannelID, + ) + require.Equal( + t, expectedNode2Channel, + g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID], + ) + + // The external view should reflect this as well. 
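+	// By the "external view" we mean going through the exported
+	// ForEachChannel iterator rather than reading the cache's internal
+	// maps directly as done above.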
+ var foundChannel *DirectedChannel + err := g.graphCache.ForEachChannel( + e.NodeKey1Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.ChannelID { + foundChannel = c + } + + return nil + }, + ) + require.NoError(t, err) + require.NotNil(t, foundChannel) + require.Equal(t, expectedNode1Channel, foundChannel) + + err = g.graphCache.ForEachChannel( + e.NodeKey2Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.ChannelID { + foundChannel = c + } + + return nil + }, + ) + require.NoError(t, err) + require.NotNil(t, foundChannel) + require.Equal(t, expectedNode2Channel, foundChannel) +} + +func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { + // Make sure no channel in the cache has the given channel ID. If there + // are no channels at all, that is fine as well. + for _, channels := range g.graphCache.nodeChannels { + for _, channel := range channels { + require.NotEqual(t, channel.ChannelID, chanID) + } + } +} + +func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, + e *ChannelEdgeInfo, p *ChannelEdgePolicy, policy1 bool) { + + // Check the internal state first. + c1, ok := g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID] + require.True(t, ok) + + if policy1 { + require.True(t, c1.OutPolicySet) + } else { + require.NotNil(t, c1.InPolicy) + require.Equal( + t, p.FeeProportionalMillionths, + c1.InPolicy.FeeProportionalMillionths, + ) + } + + c2, ok := g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID] + require.True(t, ok) + + if policy1 { + require.NotNil(t, c2.InPolicy) + require.Equal( + t, p.FeeProportionalMillionths, + c2.InPolicy.FeeProportionalMillionths, + ) + } else { + require.True(t, c2.OutPolicySet) + } + + // Now for both nodes make sure that the external view is also correct. + var ( + c1Ext *DirectedChannel + c2Ext *DirectedChannel + ) + require.NoError(t, g.graphCache.ForEachChannel( + e.NodeKey1Bytes, func(c *DirectedChannel) error { + c1Ext = c + + return nil + }, + )) + require.NoError(t, g.graphCache.ForEachChannel( + e.NodeKey2Bytes, func(c *DirectedChannel) error { + c2Ext = c + + return nil + }, + )) + + // Only compare the fields that are actually copied, then compare the + // values of the functions separately. 
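+	// DeepCopy strips the ToNodePubKey/ToNodeFeatures that ForEachChannel
+	// sets on the returned copy, making the result directly comparable to
+	// the internal cache entry where those fields are never populated.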
+ require.Equal(t, c1, c1Ext.DeepCopy()) + require.Equal(t, c2, c2Ext.DeepCopy()) + if policy1 { + require.Equal( + t, p.FeeProportionalMillionths, + c2Ext.InPolicy.FeeProportionalMillionths, + ) + require.Equal( + t, route.Vertex(e.NodeKey2Bytes), + c2Ext.InPolicy.ToNodePubKey(), + ) + require.Equal(t, testFeatures, c2Ext.InPolicy.ToNodeFeatures) + } else { + require.Equal( + t, p.FeeProportionalMillionths, + c1Ext.InPolicy.FeeProportionalMillionths, + ) + require.Equal( + t, route.Vertex(e.NodeKey1Bytes), + c1Ext.InPolicy.ToNodePubKey(), + ) + require.Equal(t, testFeatures, c1Ext.InPolicy.ToNodeFeatures) + } +} + +func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy { + update := prand.Int63() + + return newEdgePolicy(chanID, db, update) +} + +func newEdgePolicy(chanID uint64, db kvdb.Backend, updateTime int64) *ChannelEdgePolicy { return &ChannelEdgePolicy{ @@ -851,116 +1089,18 @@ func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, func TestGraphTraversal(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the - // graph. + // graph. And we'll create 5 channels between each node pair. const numNodes = 20 - nodes := make([]*LightningNode, numNodes) - nodeIndex := map[string]struct{}{} - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - nodes[i] = node - nodeIndex[node.Alias] = struct{}{} - } - - // Add each of the nodes into the graph, they should be inserted - // without error. - for _, node := range nodes { - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - - // Iterate over each node as returned by the graph, if all nodes are - // reached, then the map created above should be empty. - err = graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(nodeIndex) != 0 { - t.Fatalf("all nodes not reached within ForEach") - } - - // Determine which node is "smaller", we'll need this in order to - // properly create the edges for the graph. - var firstNode, secondNode *LightningNode - if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { - firstNode = nodes[0] - secondNode = nodes[1] - } else { - firstNode = nodes[0] - secondNode = nodes[1] - } - - // Create 5 channels between the first two nodes we generated above. 
const numChannels = 5 - chanIndex := map[uint64]struct{}{} - for i := 0; i < numChannels; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) - err := graph.AddChannelEdge(&edgeInfo) - if err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create and add an edge with random data that points from - // node1 -> node2. - edge := randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 0 - edge.Node = secondNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node2 -> node1 - // this time. - edge = randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 1 - edge.Node = firstNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - chanIndex[chanID] = struct{}{} - } + chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have @@ -971,16 +1111,13 @@ func TestGraphTraversal(t *testing.T) { delete(chanIndex, ei.ChannelID) return nil }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(chanIndex) != 0 { - t.Fatalf("all edges not reached within ForEach") - } + require.NoError(t, err) + require.Len(t, chanIndex, 0) // Finally, we want to test the ability to iterate over all the // outgoing channels for a particular node. numNodeChans := 0 + firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo, outEdge, inEdge *ChannelEdgePolicy) error { @@ -1005,13 +1142,202 @@ func TestGraphTraversal(t *testing.T) { numNodeChans++ return nil }) + require.NoError(t, err) + require.Equal(t, numChannels, numNodeChans) +} + +// TestGraphTraversalCacheable tests that the memory optimized node traversal is +// working correctly. +func TestGraphTraversalCacheable(t *testing.T) { + t.Parallel() + + graph, cleanUp, err := MakeTestGraph() + defer cleanUp() if err != nil { - t.Fatalf("for each failure: %v", err) + t.Fatalf("unable to make test database: %v", err) } - if numNodeChans != numChannels { - t.Fatalf("all edges for node not reached within ForEach: "+ - "expected %v, got %v", numChannels, numNodeChans) + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. And we'll create 5 channels between the first two nodes. + const numNodes = 20 + const numChannels = 5 + chanIndex, _ := fillTestGraph(t, graph, numNodes, numChannels) + + // Create a map of all nodes with the iteration we know works (because + // it is tested in another test). 
+ nodeMap := make(map[route.Vertex]struct{}) + err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error { + nodeMap[n.PubKeyBytes] = struct{}{} + + return nil + }) + require.NoError(t, err) + require.Len(t, nodeMap, numNodes) + + // Iterate through all the known channels within the graph DB by + // iterating over each node, once again if the map is empty that + // indicates that all edges have properly been reached. + err = graph.ForEachNodeCacheable( + func(tx kvdb.RTx, node GraphCacheNode) error { + delete(nodeMap, node.PubKey()) + + return node.ForEachChannel( + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + policy *ChannelEdgePolicy, + policy2 *ChannelEdgePolicy) error { + + delete(chanIndex, info.ChannelID) + return nil + }, + ) + }, + ) + require.NoError(t, err) + require.Len(t, nodeMap, 0) + require.Len(t, chanIndex, 0) +} + +func TestGraphCacheTraversal(t *testing.T) { + t.Parallel() + + graph, cleanUp, err := MakeTestGraph() + defer cleanUp() + require.NoError(t, err) + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. And we'll create 5 channels between each node pair. + const numNodes = 20 + const numChannels = 5 + chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) + + // Iterate through all the known channels within the graph DB, once + // again if the map is empty that indicates that all edges have + // properly been reached. + numNodeChans := 0 + for _, node := range nodeList { + node := node + + err = graph.graphCache.ForEachChannel( + node.PubKeyBytes, func(d *DirectedChannel) error { + delete(chanIndex, d.ChannelID) + + if !d.OutPolicySet || d.InPolicy == nil { + return fmt.Errorf("channel policy not " + + "present") + } + + // The incoming edge should also indicate that + // it's pointing to the origin node. + inPolicyNodeKey := d.InPolicy.ToNodePubKey() + if !bytes.Equal( + inPolicyNodeKey[:], node.PubKeyBytes[:], + ) { + return fmt.Errorf("wrong outgoing edge") + } + + numNodeChans++ + + return nil + }, + ) + require.NoError(t, err) } + require.Len(t, chanIndex, 0) + + // We count the channels for both nodes, so there should be double the + // amount now. Except for the very last node, that doesn't have any + // channels to make the loop easier in fillTestGraph(). + require.Equal(t, numChannels*2*(numNodes-1), numNodeChans) +} + +func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, + numChannels int) (map[uint64]struct{}, []*LightningNode) { + + nodes := make([]*LightningNode, numNodes) + nodeIndex := map[string]struct{}{} + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(graph.db) + require.NoError(t, err) + + nodes[i] = node + nodeIndex[node.Alias] = struct{}{} + } + + // Add each of the nodes into the graph, they should be inserted + // without error. + for _, node := range nodes { + require.NoError(t, graph.AddLightningNode(node)) + } + + // Iterate over each node as returned by the graph, if all nodes are + // reached, then the map created above should be empty. + err := graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { + delete(nodeIndex, node.Alias) + return nil + }) + require.NoError(t, err) + require.Len(t, nodeIndex, 0) + + // Create a number of channels between each of the node pairs generated + // above. This will result in numChannels*(numNodes-1) channels. 
+ chanIndex := map[uint64]struct{}{} + for n := 0; n < numNodes-1; n++ { + node1 := nodes[n] + node2 := nodes[n+1] + if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { + node1, node2 = node2, node1 + } + + for i := 0; i < numChannels; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64((n << 8) + i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], node1.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], node2.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], node1.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], node2.PubKeyBytes[:]) + err := graph.AddChannelEdge(&edgeInfo) + require.NoError(t, err) + + // Create and add an edge with random data that points + // from node1 -> node2. + edge := randEdgePolicy(chanID, graph.db) + edge.ChannelFlags = 0 + edge.Node = node2 + edge.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(edge)) + + // Create another random edge that points from + // node2 -> node1 this time. + edge = randEdgePolicy(chanID, graph.db) + edge.ChannelFlags = 1 + edge.Node = node1 + edge.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(edge)) + + chanIndex[chanID] = struct{}{} + } + } + + return chanIndex, nodes } func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, @@ -1112,14 +1438,13 @@ func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoi func TestGraphPruning(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -1133,7 +1458,7 @@ func TestGraphPruning(t *testing.T) { const numNodes = 5 graphNodes := make([]*LightningNode, numNodes) for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) + node, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create node: %v", err) } @@ -1192,7 +1517,7 @@ func TestGraphPruning(t *testing.T) { // Create and add an edge with random data that points from // node_i -> node_i+1 - edge := randEdgePolicy(chanID, op, db) + edge := randEdgePolicy(chanID, graph.db) edge.ChannelFlags = 0 edge.Node = graphNodes[i] edge.SigBytes = testSig.Serialize() @@ -1202,7 +1527,7 @@ func TestGraphPruning(t *testing.T) { // Create another random edge that points from node_i+1 -> // node_i this time. - edge = randEdgePolicy(chanID, op, db) + edge = randEdgePolicy(chanID, graph.db) edge.ChannelFlags = 1 edge.Node = graphNodes[i] edge.SigBytes = testSig.Serialize() @@ -1320,14 +1645,12 @@ func TestGraphPruning(t *testing.T) { func TestHighestChanID(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // If we don't yet have any channels in the database, then we should // get a channel ID of zero if we ask for the highest channel ID. 
bestID, err := graph.HighestChanID() @@ -1341,11 +1664,11 @@ func TestHighestChanID(t *testing.T) { // Next, we'll insert two channels into the database, with each channel // connecting the same two nodes. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1397,14 +1720,12 @@ func TestHighestChanID(t *testing.T) { func TestChanUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // If we issue an arbitrary query before any channel updates are // inserted in the database, we should get zero results. chanUpdates, err := graph.ChanUpdatesInHorizon( @@ -1419,14 +1740,14 @@ func TestChanUpdatesInHorizon(t *testing.T) { } // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1441,12 +1762,6 @@ func TestChanUpdatesInHorizon(t *testing.T) { endTime := startTime edges := make([]ChannelEdge, 0, numChans) for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - channel, chanID := createEdge( uint32(i*10), 0, 0, 0, node1, node2, ) @@ -1460,7 +1775,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { endTime = endTime.Add(time.Second * 10) edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, edge1UpdateTime.Unix(), + chanID.ToUint64(), graph.db, edge1UpdateTime.Unix(), ) edge1.ChannelFlags = 0 edge1.Node = node2 @@ -1470,7 +1785,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { } edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, edge2UpdateTime.Unix(), + chanID.ToUint64(), graph.db, edge2UpdateTime.Unix(), ) edge2.ChannelFlags = 1 edge2.Node = node1 @@ -1573,14 +1888,12 @@ func TestChanUpdatesInHorizon(t *testing.T) { func TestNodeUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - startTime := time.Unix(1234, 0) endTime := startTime @@ -1602,7 +1915,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { const numNodes = 10 nodeAnns := make([]LightningNode, 0, numNodes) for i := 0; i < numNodes; i++ { - nodeAnn, err := createTestVertex(db) + nodeAnn, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test vertex: %v", err) } @@ -1696,14 +2009,12 @@ func TestNodeUpdatesInHorizon(t *testing.T) { func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // If we try to filter out a set of channel ID's before we even know of // any channels, then we should get the entire set back. 
preChanIDs := []uint64{1, 2, 3, 4} @@ -1716,14 +2027,14 @@ func TestFilterKnownChanIDs(t *testing.T) { } // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1813,24 +2124,22 @@ func TestFilterKnownChanIDs(t *testing.T) { func TestFilterChannelRange(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1947,24 +2256,22 @@ func TestFilterChannelRange(t *testing.T) { func TestFetchChanInfos(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1980,12 +2287,6 @@ func TestFetchChanInfos(t *testing.T) { edges := make([]ChannelEdge, 0, numChans) edgeQuery := make([]uint64, 0, numChans) for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - channel, chanID := createEdge( uint32(i*10), 0, 0, 0, node1, node2, ) @@ -1998,7 +2299,7 @@ func TestFetchChanInfos(t *testing.T) { endTime = updateTime.Add(time.Second * 10) edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edge1.ChannelFlags = 0 edge1.Node = node2 @@ -2008,7 +2309,7 @@ func TestFetchChanInfos(t *testing.T) { } edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edge2.ChannelFlags = 1 edge2.Node = node1 @@ -2075,23 +2376,21 @@ func TestFetchChanInfos(t *testing.T) { func TestIncompleteChannelPolicies(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // Create two nodes. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2099,13 +2398,6 @@ func TestIncompleteChannelPolicies(t *testing.T) { t.Fatalf("unable to add node: %v", err) } - // Create channel between nodes. - txHash := sha256.Sum256([]byte{0}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - channel, chanID := createEdge( uint32(0), 0, 0, 0, node1, node2, ) @@ -2156,7 +2448,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { updateTime := time.Unix(1234, 0) edgePolicy := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edgePolicy.ChannelFlags = 0 edgePolicy.Node = node2 @@ -2171,7 +2463,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Create second policy and assert that both policies are reported // as present. edgePolicy = newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edgePolicy.ChannelFlags = 1 edgePolicy.Node = node1 @@ -2190,14 +2482,13 @@ func TestIncompleteChannelPolicies(t *testing.T) { func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -2207,14 +2498,14 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2229,7 +2520,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1 := randEdgePolicy(chanID.ToUint64(), graph.db) edge1.ChannelFlags = 0 edge1.Node = node1 edge1.SigBytes = testSig.Serialize() @@ -2237,7 +2528,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Fatalf("unable to update edge: %v", err) } - edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge2 := randEdgePolicy(chanID.ToUint64(), graph.db) edge2.ChannelFlags = 1 edge2.Node = node2 edge2.SigBytes = testSig.Serialize() @@ -2253,7 +2544,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { timestampSet[t] = struct{}{} } - err := kvdb.View(db, func(tx kvdb.RTx) error { + err := kvdb.View(graph.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -2345,7 +2636,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { func TestPruneGraphNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2353,8 +2644,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll start off by inserting our source node, to ensure that it's // the only node left after we prune the graph. - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -2365,21 +2655,21 @@ func TestPruneGraphNodes(t *testing.T) { // With the source node inserted, we'll now add three nodes to the // channel graph, at the end of the scenario, only two of these nodes // should still be in the graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } - node3, err := createTestVertex(db) + node3, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2396,7 +2686,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll now insert an advertised edge, but it'll only be the edge that // points from the first to the second node. - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1 := randEdgePolicy(chanID.ToUint64(), graph.db) edge1.ChannelFlags = 0 edge1.Node = node1 edge1.SigBytes = testSig.Serialize() @@ -2417,7 +2707,7 @@ func TestPruneGraphNodes(t *testing.T) { // Finally, we'll ensure that node3, the only fully unconnected node as // properly deleted from the graph and not another node in its place. 
- _, err = graph.FetchLightningNode(nil, node3.PubKeyBytes) + _, err = graph.FetchLightningNode(node3.PubKeyBytes) if err == nil { t.Fatalf("node 3 should have been deleted!") } @@ -2429,24 +2719,22 @@ func TestPruneGraphNodes(t *testing.T) { func TestAddChannelEdgeShellNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // To start, we'll create two nodes, and only add one of them to the // channel graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2460,7 +2748,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // Ensure that node1 was inserted as a full node, while node2 only has // a shell node present. - node1, err = graph.FetchLightningNode(nil, node1.PubKeyBytes) + node1, err = graph.FetchLightningNode(node1.PubKeyBytes) if err != nil { t.Fatalf("unable to fetch node1: %v", err) } @@ -2468,7 +2756,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { t.Fatalf("have shell announcement for node1, shouldn't") } - node2, err = graph.FetchLightningNode(nil, node2.PubKeyBytes) + node2, err = graph.FetchLightningNode(node2.PubKeyBytes) if err != nil { t.Fatalf("unable to fetch node2: %v", err) } @@ -2483,17 +2771,15 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'll first populate our graph with a single node that will be // removed shortly. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2553,44 +2839,41 @@ func TestNodeIsPublic(t *testing.T) { // We'll need to create a separate database and channel graph for each // participant to replicate real-world scenarios (private edges being in // some graphs but not others, etc.). 
- aliceDB, cleanUp, err := MakeTestDB() + aliceGraph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - aliceNode, err := createTestVertex(aliceDB) + aliceNode, err := createTestVertex(aliceGraph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - aliceGraph := aliceDB.ChannelGraph() if err := aliceGraph.SetSourceNode(aliceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } - bobDB, cleanUp, err := MakeTestDB() + bobGraph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - bobNode, err := createTestVertex(bobDB) + bobNode, err := createTestVertex(bobGraph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - bobGraph := bobDB.ChannelGraph() if err := bobGraph.SetSourceNode(bobNode); err != nil { t.Fatalf("unable to set source node: %v", err) } - carolDB, cleanUp, err := MakeTestDB() + carolGraph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - carolNode, err := createTestVertex(carolDB) + carolNode, err := createTestVertex(carolGraph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - carolGraph := carolDB.ChannelGraph() if err := carolGraph.SetSourceNode(carolNode); err != nil { t.Fatalf("unable to set source node: %v", err) } @@ -2602,7 +2885,7 @@ func TestNodeIsPublic(t *testing.T) { // participant's graph. nodes := []*LightningNode{aliceNode, bobNode, carolNode} edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} - dbs := []*DB{aliceDB, bobDB, carolDB} + dbs := []kvdb.Backend{aliceGraph.db, bobGraph.db, carolGraph.db} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for i, graph := range graphs { for _, node := range nodes { @@ -2702,16 +2985,14 @@ func TestNodeIsPublic(t *testing.T) { func TestDisabledChannelIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - graph := db.ChannelGraph() - // Create first node and add it to the graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2720,7 +3001,7 @@ func TestDisabledChannelIDs(t *testing.T) { } // Create second node and add it to the graph. - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2729,7 +3010,7 @@ func TestDisabledChannelIDs(t *testing.T) { } // Adding a new channel edge to the graph. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2800,29 +3081,27 @@ func TestDisabledChannelIDs(t *testing.T) { func TestEdgePolicyMissingMaxHtcl(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2862,7 +3141,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { // Attempting to deserialize these bytes should return an error. r := bytes.NewReader(stripped) - err = kvdb.View(db, func(tx kvdb.RTx) error { + err = kvdb.View(graph.db, func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) if nodes == nil { return ErrGraphNotFound @@ -2882,7 +3161,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { } // Put the stripped bytes in the DB. - err = kvdb.Update(db, func(tx kvdb.RwTx) error { + err = kvdb.Update(graph.db, func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return ErrEdgeNotFound @@ -2980,18 +3259,17 @@ func TestGraphZombieIndex(t *testing.T) { t.Parallel() // We'll start by creating our test graph along with a test edge. - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to create test database: %v", err) } - graph := db.ChannelGraph() - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test vertex: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test vertex: %v", err) } @@ -3002,7 +3280,7 @@ func TestGraphZombieIndex(t *testing.T) { node1, node2 = node2, node1 } - edge, _, _ := createChannelEdge(db, node1, node2) + edge, _, _ := createChannelEdge(graph.db, node1, node2) if err := graph.AddChannelEdge(edge); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -3155,7 +3433,7 @@ func compareEdgePolicies(a, b *ChannelEdgePolicy) error { return nil } -// TestLightningNodeSigVerifcation checks that we can use the LightningNode's +// TestLightningNodeSigVerification checks that we can use the LightningNode's // pubkey to verify signatures. func TestLightningNodeSigVerification(t *testing.T) { t.Parallel() @@ -3183,13 +3461,13 @@ func TestLightningNodeSigVerification(t *testing.T) { } // Create a LightningNode from the same private key. - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - node, err := createLightningNode(db, priv) + node, err := createLightningNode(graph.db, priv) if err != nil { t.Fatalf("unable to create node: %v", err) } @@ -3233,21 +3511,20 @@ func TestComputeFee(t *testing.T) { func TestBatchedAddChannelEdge(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() require.Nil(t, err) defer cleanUp() - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.SetSourceNode(sourceNode) require.Nil(t, err) // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) require.Nil(t, err) - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) require.Nil(t, err) // In addition to the fake vertexes we create some fake channel @@ -3316,25 +3593,23 @@ func TestBatchedAddChannelEdge(t *testing.T) { func TestBatchedUpdateEdgePolicy(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() require.Nil(t, err) defer cleanUp() - graph := db.ChannelGraph() - // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.AddLightningNode(node1) require.Nil(t, err) - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.AddLightningNode(node2) require.Nil(t, err) // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. @@ -3372,3 +3647,47 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { require.Nil(t, err) } } + +// BenchmarkForEachChannel is a benchmark test that measures the number of +// allocations and the total memory consumed by the full graph traversal. +func BenchmarkForEachChannel(b *testing.B) { + graph, cleanUp, err := MakeTestGraph() + require.Nil(b, err) + defer cleanUp() + + const numNodes = 100 + const numChannels = 4 + _, _ = fillTestGraph(b, graph, numNodes, numChannels) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var ( + totalCapacity btcutil.Amount + maxHTLCs lnwire.MilliSatoshi + ) + err := graph.ForEachNodeCacheable( + func(tx kvdb.RTx, n GraphCacheNode) error { + return n.ForEachChannel( + tx, func(tx kvdb.RTx, + info *ChannelEdgeInfo, + policy *ChannelEdgePolicy, + policy2 *ChannelEdgePolicy) error { + + // We need to do something with + // the data here, otherwise the + // compiler is going to optimize + // this away, and we get bogus + // results. + totalCapacity += info.Capacity + maxHTLCs += policy.MaxHTLC + maxHTLCs += policy2.MaxHTLC + + return nil + }, + ) + }, + ) + require.NoError(b, err) + } +} diff --git a/channeldb/nodes.go b/channeldb/nodes.go index 88d98d6ae..ffc7414c5 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -56,12 +56,14 @@ type LinkNode struct { // authenticated connection for the stored identity public key. Addresses []net.Addr - db *DB + // db is the database instance this node was fetched from. This is used + // to sync back the node's state if it is updated. + db *LinkNodeDB } // NewLinkNode creates a new LinkNode from the provided parameters, which is -// backed by an instance of channeldb. -func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, +// backed by an instance of a link node DB. +func NewLinkNode(db *LinkNodeDB, bitNet wire.BitcoinNet, pub *btcec.PublicKey, addrs ...net.Addr) *LinkNode { return &LinkNode{ @@ -69,7 +71,7 @@ func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, IdentityPub: pub, LastSeen: time.Now(), Addresses: addrs, - db: d, + db: db, } } @@ -98,10 +100,9 @@ func (l *LinkNode) AddAddress(addr net.Addr) error { // Sync performs a full database sync which writes the current up-to-date data // within the struct to the database. 
func (l *LinkNode) Sync() error { - // Finally update the database by storing the link node and updating // any relevant indexes. - return kvdb.Update(l.db, func(tx kvdb.RwTx) error { + return kvdb.Update(l.db.backend, func(tx kvdb.RwTx) error { nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket) if nodeMetaBucket == nil { return ErrLinkNodesNotFound @@ -127,15 +128,20 @@ func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error { return nodeMetaBucket.Put(nodePub, b.Bytes()) } +// LinkNodeDB is a database that keeps track of all link nodes. +type LinkNodeDB struct { + backend kvdb.Backend +} + // DeleteLinkNode removes the link node with the given identity from the // database. -func (d *DB) DeleteLinkNode(identity *btcec.PublicKey) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { - return d.deleteLinkNode(tx, identity) +func (l *LinkNodeDB) DeleteLinkNode(identity *btcec.PublicKey) error { + return kvdb.Update(l.backend, func(tx kvdb.RwTx) error { + return deleteLinkNode(tx, identity) }, func() {}) } -func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { +func deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket) if nodeMetaBucket == nil { return ErrLinkNodesNotFound @@ -148,9 +154,9 @@ func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { // FetchLinkNode attempts to lookup the data for a LinkNode based on a target // identity public key. If a particular LinkNode for the passed identity public // key cannot be found, then ErrNodeNotFound if returned. -func (d *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { +func (l *LinkNodeDB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { var linkNode *LinkNode - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(l.backend, func(tx kvdb.RTx) error { node, err := fetchLinkNode(tx, identity) if err != nil { return err @@ -191,10 +197,10 @@ func fetchLinkNode(tx kvdb.RTx, targetPub *btcec.PublicKey) (*LinkNode, error) { // FetchAllLinkNodes starts a new database transaction to fetch all nodes with // whom we have active channels with. -func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) { +func (l *LinkNodeDB) FetchAllLinkNodes() ([]*LinkNode, error) { var linkNodes []*LinkNode - err := kvdb.View(d, func(tx kvdb.RTx) error { - nodes, err := d.fetchAllLinkNodes(tx) + err := kvdb.View(l.backend, func(tx kvdb.RTx) error { + nodes, err := fetchAllLinkNodes(tx) if err != nil { return err } @@ -213,7 +219,7 @@ func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) { // fetchAllLinkNodes uses an existing database transaction to fetch all nodes // with whom we have active channels with. -func (d *DB) fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) { +func fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) { nodeMetaBucket := tx.ReadBucket(nodeInfoBucket) if nodeMetaBucket == nil { return nil, ErrLinkNodesNotFound diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 0d649d431..8f60a7986 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -13,12 +13,14 @@ import ( func TestLinkNodeEncodeDecode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First we'll create some initial data to use for populating our test // LinkNode instances. 
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) @@ -34,8 +36,8 @@ func TestLinkNodeEncodeDecode(t *testing.T) { // Create two fresh link node instances with the above dummy data, then // fully sync both instances to disk. - node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) - node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) + node1 := NewLinkNode(cdb.linkNodeDB, wire.MainNet, pub1, addr1) + node2 := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pub2, addr2) if err := node1.Sync(); err != nil { t.Fatalf("unable to sync node: %v", err) } @@ -46,7 +48,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) { // Fetch all current link nodes from the database, they should exactly // match the two created above. originalNodes := []*LinkNode{node2, node1} - linkNodes, err := cdb.FetchAllLinkNodes() + linkNodes, err := cdb.linkNodeDB.FetchAllLinkNodes() if err != nil { t.Fatalf("unable to fetch nodes: %v", err) } @@ -82,7 +84,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) { } // Fetch the same node from the database according to its public key. - node1DB, err := cdb.FetchLinkNode(pub1) + node1DB, err := cdb.linkNodeDB.FetchLinkNode(pub1) if err != nil { t.Fatalf("unable to find node: %v", err) } @@ -110,31 +112,33 @@ func TestLinkNodeEncodeDecode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 1337, } - linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) + linkNode := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pubKey, addr) if err := linkNode.Sync(); err != nil { t.Fatalf("unable to write link node to db: %v", err) } - if _, err := cdb.FetchLinkNode(pubKey); err != nil { + if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err != nil { t.Fatalf("unable to find link node: %v", err) } - if err := cdb.DeleteLinkNode(pubKey); err != nil { + if err := cdb.linkNodeDB.DeleteLinkNode(pubKey); err != nil { t.Fatalf("unable to delete link node from db: %v", err) } - if _, err := cdb.FetchLinkNode(pubKey); err == nil { + if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err == nil { t.Fatal("should not have found link node in db, but did") } } diff --git a/channeldb/options.go b/channeldb/options.go index ceb29bf7b..ad22fa8ed 100644 --- a/channeldb/options.go +++ b/channeldb/options.go @@ -17,6 +17,12 @@ const ( // in order to reply to gossip queries. This produces a cache size of // around 40MB. DefaultChannelCacheSize = 20000 + + // DefaultPreAllocCacheNumNodes is the default number of nodes we + // assume for mainnet for pre-allocating the graph cache. As of + // September 2021, there are roughly 14k nodes in a strictly pruned + // graph, so we choose a number that is slightly higher. + DefaultPreAllocCacheNumNodes = 15000 ) // Options holds parameters for tuning and customizing a channeldb.DB. @@ -35,6 +41,10 @@ type Options struct { // wait before attempting to commit a pending set of updates. BatchCommitInterval time.Duration + // PreAllocCacheNumNodes is the number of nodes we expect to be in the + // graph cache, so we can pre-allocate the map accordingly. + PreAllocCacheNumNodes int + // clock is the time source used by the database. 
clock clock.Clock @@ -52,9 +62,10 @@ func DefaultOptions() Options { AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, DBTimeout: kvdb.DefaultDBTimeout, }, - RejectCacheSize: DefaultRejectCacheSize, - ChannelCacheSize: DefaultChannelCacheSize, - clock: clock.NewDefaultClock(), + RejectCacheSize: DefaultRejectCacheSize, + ChannelCacheSize: DefaultChannelCacheSize, + PreAllocCacheNumNodes: DefaultPreAllocCacheNumNodes, + clock: clock.NewDefaultClock(), } } @@ -75,6 +86,13 @@ func OptionSetChannelCacheSize(n int) OptionModifier { } } +// OptionSetPreAllocCacheNumNodes sets the PreAllocCacheNumNodes to n. +func OptionSetPreAllocCacheNumNodes(n int) OptionModifier { + return func(o *Options) { + o.PreAllocCacheNumNodes = n + } +} + // OptionSetSyncFreelist allows the database to sync its freelist. func OptionSetSyncFreelist(b bool) OptionModifier { return func(o *Options) { diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index e8a09b758..7bb53e179 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -36,12 +36,12 @@ type WaitingProofStore struct { // cache is used in order to reduce the number of redundant get // calls, when object isn't stored in it. cache map[WaitingProofKey]struct{} - db *DB + db kvdb.Backend mu sync.RWMutex } // NewWaitingProofStore creates new instance of proofs storage. -func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { +func NewWaitingProofStore(db kvdb.Backend) (*WaitingProofStore, error) { s := &WaitingProofStore{ db: db, } diff --git a/channelnotifier/channelnotifier.go b/channelnotifier/channelnotifier.go index 74c2b15eb..2cf6015c4 100644 --- a/channelnotifier/channelnotifier.go +++ b/channelnotifier/channelnotifier.go @@ -17,7 +17,7 @@ type ChannelNotifier struct { ntfnServer *subscribe.Server - chanDB *channeldb.DB + chanDB *channeldb.ChannelStateDB } // PendingOpenChannelEvent represents a new event where a new channel has @@ -76,7 +76,7 @@ type FullyResolvedChannelEvent struct { // New creates a new channel notifier. The ChannelNotifier gets channel // events from peers and from the chain arbitrator, and dispatches them to // its clients. -func New(chanDB *channeldb.DB) *ChannelNotifier { +func New(chanDB *channeldb.ChannelStateDB) *ChannelNotifier { return &ChannelNotifier{ ntfnServer: subscribe.NewServer(), chanDB: chanDB, diff --git a/chanrestore.go b/chanrestore.go index 7527499cd..cd68b5077 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -34,7 +34,7 @@ const ( // need the secret key chain in order obtain the prior shachain root so we can // verify the DLP protocol as initiated by the remote node. type chanDBRestorer struct { - db *channeldb.DB + db *channeldb.ChannelStateDB secretKeys keychain.SecretKeyRing diff --git a/contractcourt/breacharbiter.go b/contractcourt/breacharbiter.go index 112aa5bce..3253e0009 100644 --- a/contractcourt/breacharbiter.go +++ b/contractcourt/breacharbiter.go @@ -136,7 +136,7 @@ type BreachConfig struct { // DB provides access to the user's channels, allowing the breach // arbiter to determine the current state of a user's channels, and how // it should respond to channel closure. - DB *channeldb.DB + DB *channeldb.ChannelStateDB // Estimator is used by the breach arbiter to determine an appropriate // fee level when generating, signing, and broadcasting sweep @@ -1432,11 +1432,11 @@ func (b *BreachArbiter) sweepSpendableOutputsTxn(txWeight int64, // store is to ensure that we can recover from a restart in the middle of a // breached contract retribution. 
type RetributionStore struct { - db *channeldb.DB + db kvdb.Backend } // NewRetributionStore creates a new instance of a RetributionStore. -func NewRetributionStore(db *channeldb.DB) *RetributionStore { +func NewRetributionStore(db kvdb.Backend) *RetributionStore { return &RetributionStore{ db: db, } diff --git a/contractcourt/breacharbiter_test.go b/contractcourt/breacharbiter_test.go index 0d423584c..61e819391 100644 --- a/contractcourt/breacharbiter_test.go +++ b/contractcourt/breacharbiter_test.go @@ -987,7 +987,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter, contractBreaches := make(chan *ContractBreachEvent) brar, cleanUpArb, err := createTestArbiter( - t, contractBreaches, alice.State().Db, + t, contractBreaches, alice.State().Db.GetParentDB(), ) if err != nil { t.Fatalf("unable to initialize test breach arbiter: %v", err) @@ -1164,7 +1164,7 @@ func TestBreachHandoffFail(t *testing.T) { assertNotPendingClosed(t, alice) brar, cleanUpArb, err := createTestArbiter( - t, contractBreaches, alice.State().Db, + t, contractBreaches, alice.State().Db.GetParentDB(), ) if err != nil { t.Fatalf("unable to initialize test breach arbiter: %v", err) @@ -2075,7 +2075,7 @@ func assertNoArbiterBreach(t *testing.T, brar *BreachArbiter, // assertBrarCleanup blocks until the given channel point has been removed the // retribution store and the channel is fully closed in the database. func assertBrarCleanup(t *testing.T, brar *BreachArbiter, - chanPoint *wire.OutPoint, db *channeldb.DB) { + chanPoint *wire.OutPoint, db *channeldb.ChannelStateDB) { t.Helper() @@ -2174,7 +2174,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent, notifier := mock.MakeMockSpendNotifier() ba := NewBreachArbiter(&BreachConfig{ CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {}, - DB: db, + DB: db.ChannelStateDB(), Estimator: chainfee.NewStaticEstimator(12500, 0), GenSweepScript: func() ([]byte, error) { return nil, nil }, ContractBreaches: contractBreaches, @@ -2375,7 +2375,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } @@ -2393,7 +2393,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobCommit, RemoteCommitment: bobCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 426382dd3..aeeff69f9 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -258,7 +258,9 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, // same instance that is used by the link. chanPoint := a.channel.FundingOutpoint - channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( + nil, chanPoint, + ) if err != nil { return nil, err } @@ -301,7 +303,9 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) // Now that we know the link can't mutate the channel // state, we'll read the channel from disk the target // channel according to its channel point. 
- channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( + nil, chanPoint, + ) if err != nil { return nil, err } @@ -422,7 +426,7 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { // First, we'll we'll mark the channel as fully closed from the PoV of // the channel source. - err := c.chanSource.MarkChanFullyClosed(&chanPoint) + err := c.chanSource.ChannelStateDB().MarkChanFullyClosed(&chanPoint) if err != nil { log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+ "fully closed: %v", chanPoint, err) @@ -480,7 +484,7 @@ func (c *ChainArbitrator) Start() error { // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. - openChannels, err := c.chanSource.FetchAllChannels() + openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() if err != nil { return err } @@ -538,7 +542,9 @@ func (c *ChainArbitrator) Start() error { // In addition to the channels that we know to be open, we'll also // launch arbitrators to finishing resolving any channels that are in // the pending close state. - closingChannels, err := c.chanSource.FetchClosedChannels(true) + closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels( + true, + ) if err != nil { return err } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index e197c0b09..cb1648065 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -49,7 +49,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { // We manually set the db here to make sure all channels are // synced to the same db. - channel.Db = db + channel.Db = db.ChannelStateDB() addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -165,7 +165,7 @@ func TestResolveContract(t *testing.T) { } defer cleanup() channel := newChannel.State() - channel.Db = db + channel.Db = db.ChannelStateDB() addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 18556, @@ -205,7 +205,7 @@ func TestResolveContract(t *testing.T) { // While the resolver are active, we'll now remove the channel from the // database (mark is as closed). - err = db.AbandonChannel(&channel.FundingOutpoint, 4) + err = db.ChannelStateDB().AbandonChannel(&channel.FundingOutpoint, 4) if err != nil { t.Fatalf("unable to remove channel: %v", err) } diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 11f23d8cc..0023402c1 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -58,7 +58,7 @@ func copyChannelState(state *channeldb.OpenChannel) ( *channeldb.OpenChannel, func(), error) { // Make a copy of the DB. 
- dbFile := filepath.Join(state.Db.Path(), "channel.db") + dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db") tempDbPath, err := ioutil.TempDir("", "past-state") if err != nil { return nil, nil, err } @@ -81,7 +81,7 @@ func copyChannelState(state *channeldb.OpenChannel) ( return nil, nil, err } - chans, err := newDb.FetchAllChannels() + chans, err := newDb.ChannelStateDB().FetchAllChannels() if err != nil { cleanup() return nil, nil, err } diff --git a/discovery/message_store.go b/discovery/message_store.go index 4d5f9b205..40f2df78a 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) @@ -59,7 +58,7 @@ type GossipMessageStore interface { // version of a message (like in the case of multiple ChannelUpdate's) for a // channel with a peer. type MessageStore struct { - db *channeldb.DB + db kvdb.Backend } // A compile-time assertion to ensure messageStore implements the @@ -67,8 +66,8 @@ type MessageStore struct { var _ GossipMessageStore = (*MessageStore)(nil) // NewMessageStore creates a new message store backed by a channeldb instance. -func NewMessageStore(db *channeldb.DB) (*MessageStore, error) { - err := kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error { +func NewMessageStore(db kvdb.Backend) (*MessageStore, error) { + err := kvdb.Batch(db, func(tx kvdb.RwTx) error { _, err := tx.CreateTopLevelBucket(messageStoreBucket) return err }) @@ -124,7 +123,7 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error return err } - return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) if messageStore == nil { return ErrCorruptedMessageStore @@ -145,7 +144,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, return err } - return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) if messageStore == nil { return ErrCorruptedMessageStore diff --git a/docs/release-notes/release-notes-0.14.0.md b/docs/release-notes/release-notes-0.14.0.md index 92699bf75..ac1619312 100644 --- a/docs/release-notes/release-notes-0.14.0.md +++ b/docs/release-notes/release-notes-0.14.0.md @@ -59,6 +59,18 @@ in `lnd`, saving developer time and limiting the potential for bugs. Instructions for enabling Postgres can be found in [docs/postgres.md](../postgres.md). +### In-memory path finding + +Finding a path through the channel graph for sending a payment doesn't involve +any database queries anymore. The [channel graph is now kept fully +in-memory](https://github.com/lightningnetwork/lnd/pull/5642) for a massive +performance boost when calling `QueryRoutes` or any of the `SendPayment` +variants. Keeping the full graph in memory naturally comes with increased RAM +usage. Users running `lnd` on low-memory systems are advised to run with the +`routing.strictgraphpruning=true` configuration option that more aggressively +removes zombie channels from the graph, reducing the number of channels that +need to be kept in memory. 
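As a minimal sketch of how a memory-constrained node might enable the pruning option mentioned in the release note above: the flag name `routing.strictgraphpruning=true` is taken directly from the note, while its placement under a `[routing]` section header in `lnd.conf` is an assumption that should be checked against the sample configuration shipped with the release.

```
# lnd.conf (assumed section placement; consult sample-lnd.conf for the
# authoritative layout). The same option can be passed on the command line
# instead: lnd --routing.strictgraphpruning=true
[routing]
routing.strictgraphpruning=true
```

With strict pruning enabled, zombie channels are dropped from the in-memory graph more eagerly, trading a slightly sparser graph for lower RAM usage.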
+ ## Protocol Extensions ### Explicit Channel Negotiation diff --git a/funding/manager.go b/funding/manager.go index 9725630ee..47b7634fa 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -23,7 +23,6 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnrpc" @@ -550,19 +549,6 @@ const ( addedToRouterGraph ) -var ( - // channelOpeningStateBucket is the database bucket used to store the - // channelOpeningState for each channel that is currently in the process - // of being opened. - channelOpeningStateBucket = []byte("channelOpeningState") - - // ErrChannelNotFound is an error returned when a channel is not known - // to us. In this case of the fundingManager, this error is returned - // when the channel in question is not considered being in an opening - // state. - ErrChannelNotFound = fmt.Errorf("channel not found") -) - // NewFundingManager creates and initializes a new instance of the // fundingManager. func NewFundingManager(cfg Config) (*Manager, error) { @@ -887,7 +873,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, channelState, shortChanID, err := f.getChannelOpeningState( &channel.FundingOutpoint, ) - if err == ErrChannelNotFound { + if err == channeldb.ErrChannelNotFound { // Channel not in fundingManager's opening database, // meaning it was successfully announced to the // network. @@ -3551,26 +3537,20 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey { // chanPoint to the channelOpeningStateBucket. func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, state channelOpeningState, shortChanID *lnwire.ShortChannelID) error { - return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { - bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) - if err != nil { - return err - } + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return err + } - var outpointBytes bytes.Buffer - if err = WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - // Save state and the uint64 representation of the shortChanID - // for later use. - scratch := make([]byte, 10) - byteOrder.PutUint16(scratch[:2], uint16(state)) - byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) - - return bucket.Put(outpointBytes.Bytes(), scratch) - }, func() {}) + // Save state and the uint64 representation of the shortChanID + // for later use. 
+ scratch := make([]byte, 10) + byteOrder.PutUint16(scratch[:2], uint16(state)) + byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) + return f.cfg.Wallet.Cfg.Database.SaveChannelOpeningState( + outpointBytes.Bytes(), scratch, + ) } // getChannelOpeningState fetches the channelOpeningState for the provided @@ -3579,51 +3559,31 @@ func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, func (f *Manager) getChannelOpeningState(chanPoint *wire.OutPoint) ( channelOpeningState, *lnwire.ShortChannelID, error) { - var state channelOpeningState - var shortChanID lnwire.ShortChannelID - err := kvdb.View(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RTx) error { + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return 0, nil, err + } - bucket := tx.ReadBucket(channelOpeningStateBucket) - if bucket == nil { - // If the bucket does not exist, it means we never added - // a channel to the db, so return ErrChannelNotFound. - return ErrChannelNotFound - } - - var outpointBytes bytes.Buffer - if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - value := bucket.Get(outpointBytes.Bytes()) - if value == nil { - return ErrChannelNotFound - } - - state = channelOpeningState(byteOrder.Uint16(value[:2])) - shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) - return nil - }, func() {}) + value, err := f.cfg.Wallet.Cfg.Database.GetChannelOpeningState( + outpointBytes.Bytes(), + ) if err != nil { return 0, nil, err } + state := channelOpeningState(byteOrder.Uint16(value[:2])) + shortChanID := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) return state, &shortChanID, nil } // deleteChannelOpeningState removes any state for chanPoint from the database. 
func (f *Manager) deleteChannelOpeningState(chanPoint *wire.OutPoint) error { - return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { - bucket := tx.ReadWriteBucket(channelOpeningStateBucket) - if bucket == nil { - return fmt.Errorf("bucket not found") - } + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return err + } - var outpointBytes bytes.Buffer - if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - return bucket.Delete(outpointBytes.Bytes()) - }, func() {}) + return f.cfg.Wallet.Cfg.Database.DeleteChannelOpeningState( + outpointBytes.Bytes(), + ) } diff --git a/funding/manager_test.go b/funding/manager_test.go index 636f4b0fb..8dcb97868 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -262,7 +262,7 @@ func (n *testNode) AddNewChannel(channel *channeldb.OpenChannel, } } -func createTestWallet(cdb *channeldb.DB, netParams *chaincfg.Params, +func createTestWallet(cdb *channeldb.ChannelStateDB, netParams *chaincfg.Params, notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController, signer input.Signer, keyRing keychain.SecretKeyRing, bio lnwallet.BlockChainIO, @@ -330,11 +330,13 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, } dbDir := filepath.Join(tempTestDir, "cdb") - cdb, err := channeldb.Open(dbDir) + fullDB, err := channeldb.Open(dbDir) if err != nil { return nil, err } + cdb := fullDB.ChannelStateDB() + keyRing := &mock.SecretKeyRing{ RootKey: alicePrivKey, } @@ -923,12 +925,12 @@ func assertDatabaseState(t *testing.T, node *testNode, } state, _, err = node.fundingMgr.getChannelOpeningState( fundingOutPoint) - if err != nil && err != ErrChannelNotFound { + if err != nil && err != channeldb.ErrChannelNotFound { t.Fatalf("unable to get channel state: %v", err) } // If we found the channel, check if it had the expected state. - if err != ErrChannelNotFound && state == expectedState { + if err != channeldb.ErrChannelNotFound && state == expectedState { // Got expected state, return with success. return } @@ -1166,7 +1168,7 @@ func assertErrChannelNotFound(t *testing.T, node *testNode, } state, _, err = node.fundingMgr.getChannelOpeningState( fundingOutPoint) - if err == ErrChannelNotFound { + if err == channeldb.ErrChannelNotFound { // Got expected state, return with success. return } else if err != nil { diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 951c922f0..d5bb5f376 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -199,9 +199,16 @@ type circuitMap struct { // parameterize an instance of circuitMap. type CircuitMapConfig struct { // DB provides the persistent storage engine for the circuit map. - // TODO(conner): create abstraction to allow for the substitution of - // other persistence engines. - DB *channeldb.DB + DB kvdb.Backend + + // FetchAllOpenChannels is a function that fetches all currently open + // channels from the channel database. + FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + + // FetchClosedChannels is a function that fetches all closed channels + // from the channel database. + FetchClosedChannels func( + pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) // ExtractErrorEncrypter derives the shared secret used to encrypt // errors from the obfuscator's ephemeral public key. @@ -296,7 +303,7 @@ func (cm *circuitMap) cleanClosedChannels() error { // Find closed channels and cache their ShortChannelIDs into a map. 
// This map will be used for looking up relative circuits and keystones. - closedChannels, err := cm.cfg.DB.FetchClosedChannels(false) + closedChannels, err := cm.cfg.FetchClosedChannels(false) if err != nil { return err } @@ -629,7 +636,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) { // channels. Therefore, it must be called before any links are created to avoid // interfering with normal operation. func (cm *circuitMap) trimAllOpenCircuits() error { - activeChannels, err := cm.cfg.DB.FetchAllOpenChannels() + activeChannels, err := cm.cfg.FetchAllOpenChannels() if err != nil { return err } @@ -860,7 +867,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) ( // Write the entire batch of circuits to the persistent circuit bucket // using bolt's Batch write. This method must be called from multiple, // distinct goroutines to have any impact on performance. - err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error { circuitBkt := tx.ReadWriteBucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap @@ -1091,7 +1098,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error { } cm.mtx.Unlock() - err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error { for _, circuit := range removedCircuits { // If this htlc made it to an outgoing link, load the // keystone bucket from which we will remove the diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index d3ee7b4fe..fed07958b 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -103,8 +103,11 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig, onionProcessor := newOnionProcessor(t) + db := makeCircuitDB(t, "") circuitMapCfg := &htlcswitch.CircuitMapConfig{ - DB: makeCircuitDB(t, ""), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter, } @@ -634,13 +637,17 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB { func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) ( *htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) { - // Record the current temp path and close current db. - dbPath := cfg.DB.Path() + // Record the current temp path and close current db. We know we have + // a full channeldb.DB here since we created it just above. + dbPath := cfg.DB.(*channeldb.DB).Path() cfg.DB.Close() // Reinitialize circuit map with same db path. 
+ db := makeCircuitDB(t, dbPath) cfg2 := &htlcswitch.CircuitMapConfig{ - DB: makeCircuitDB(t, dbPath), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, } cm2, err := htlcswitch.NewCircuitMap(cfg2) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 5cb97482e..4d9e10433 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1938,7 +1938,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( pCache := newMockPreimageCache() - aliceDb := aliceLc.channel.State().Db + aliceDb := aliceLc.channel.State().Db.GetParentDB() aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return nil, nil, nil, nil, nil, nil, err @@ -4438,7 +4438,7 @@ func (h *persistentLinkHarness) restartLink( pCache = newMockPreimageCache() ) - aliceDb := aliceChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() aliceSwitch := h.coreLink.cfg.Switch if restartSwitch { var err error diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 3283ec646..ac1a47dc2 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -170,8 +170,10 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) } cfg := Config{ - DB: db, - SwitchPackager: channeldb.NewSwitchPackager(), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + SwitchPackager: channeldb.NewSwitchPackager(), FwdingLog: &mockForwardingLog{ events: make(map[time.Time]channeldb.ForwardingEvent), }, diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index 2bd35f60a..8d6cb5b3a 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -83,7 +83,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) { // is back. The Switch will checkpoint any received result to the store, and // the store will keep results and notify the callers about them. type networkResultStore struct { - db *channeldb.DB + backend kvdb.Backend // results is a map from paymentIDs to channels where subscribers to // payment results will be notified. 
@@ -96,9 +96,9 @@ type networkResultStore struct { paymentIDMtx *multimutex.Mutex } -func newNetworkResultStore(db *channeldb.DB) *networkResultStore { +func newNetworkResultStore(db kvdb.Backend) *networkResultStore { return &networkResultStore{ - db: db, + backend: db, results: make(map[uint64][]chan *networkResult), paymentIDMtx: multimutex.NewMutex(), } @@ -126,7 +126,7 @@ func (store *networkResultStore) storeResult(paymentID uint64, var paymentIDBytes [8]byte binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID) - err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error { networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) @@ -171,7 +171,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) ( resultChan = make(chan *networkResult, 1) ) - err := kvdb.View(store.db, func(tx kvdb.RTx) error { + err := kvdb.View(store.backend, func(tx kvdb.RTx) error { var err error result, err = fetchResult(tx, paymentID) switch { @@ -219,7 +219,7 @@ func (store *networkResultStore) getResult(pid uint64) ( *networkResult, error) { var result *networkResult - err := kvdb.View(store.db, func(tx kvdb.RTx) error { + err := kvdb.View(store.backend, func(tx kvdb.RTx) error { var err error result, err = fetchResult(tx, pid) return err @@ -260,7 +260,7 @@ func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) { // concurrently while this process is ongoing, as its result might end up being // deleted. func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error { - return kvdb.Update(store.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Update(store.backend, func(tx kvdb.RwTx) error { networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index daa429b39..99c46737a 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -130,9 +130,18 @@ type Config struct { // subsystem. LocalChannelClose func(pubKey []byte, request *ChanClose) - // DB is the channeldb instance that will be used to back the switch's + // DB is the database backend that will be used to back the switch's // persistent circuit map. - DB *channeldb.DB + DB kvdb.Backend + + // FetchAllOpenChannels is a function that fetches all currently open + // channels from the channel database. + FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + + // FetchClosedChannels is a function that fetches all closed channels + // from the channel database. + FetchClosedChannels func( + pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) // SwitchPackager provides access to the forwarding packages of all // active channels. This gives the switch the ability to read arbitrary @@ -294,6 +303,8 @@ type Switch struct { func New(cfg Config, currentHeight uint32) (*Switch, error) { circuitMap, err := NewCircuitMap(&CircuitMapConfig{ DB: cfg.DB, + FetchAllOpenChannels: cfg.FetchAllOpenChannels, + FetchClosedChannels: cfg.FetchClosedChannels, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, }) if err != nil { @@ -1455,7 +1466,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) { // we're the originator of the payment, so the link stops attempting to // re-broadcast. 
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error { - return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.cfg.DB, func(tx kvdb.RwTx) error { return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...) }) } @@ -1859,7 +1870,7 @@ func (s *Switch) Start() error { // forwarding packages and reforwards any Settle or Fail HTLCs found. This is // used to resurrect the switch's mailboxes after a restart. func (s *Switch) reforwardResponses() error { - openChannels, err := s.cfg.DB.FetchAllOpenChannels() + openChannels, err := s.cfg.FetchAllOpenChannels() if err != nil { return err } @@ -2122,6 +2133,17 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) { return link, nil } +// GetLinkByShortID attempts to return the link which possesses the target short +// channel ID. +func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, + error) { + + s.indexMtx.RLock() + defer s.indexMtx.RUnlock() + + return s.getLinkByShortID(chanID) +} + // getLinkByShortID attempts to return the link which possesses the target // short channel ID. // diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index d33daff8f..eaf2aa99c 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -306,7 +306,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, ShortChannelID: chanID, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(chanID), FundingTxn: channels.TestFundingTx, } @@ -325,7 +325,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, LocalCommitment: bobCommit, RemoteCommitment: bobCommit, ShortChannelID: chanID, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(chanID), } @@ -384,7 +384,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } restoreAlice := func() (*lnwallet.LightningChannel, error) { - aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) + aliceStoredChannels, err := dbAlice.ChannelStateDB(). + FetchOpenChannels(aliceKeyPub) switch err { case nil: case kvdb.ErrDatabaseNotOpen: @@ -394,7 +395,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, "db: %v", err) } - aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub) + aliceStoredChannels, err = dbAlice.ChannelStateDB(). + FetchOpenChannels(aliceKeyPub) if err != nil { return nil, errors.Errorf("unable to fetch alice "+ "channel: %v", err) @@ -428,7 +430,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } restoreBob := func() (*lnwallet.LightningChannel, error) { - bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) + bobStoredChannels, err := dbBob.ChannelStateDB(). + FetchOpenChannels(bobKeyPub) switch err { case nil: case kvdb.ErrDatabaseNotOpen: @@ -438,7 +441,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, "db: %v", err) } - bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub) + bobStoredChannels, err = dbBob.ChannelStateDB(). 
+ FetchOpenChannels(bobKeyPub) if err != nil { return nil, errors.Errorf("unable to fetch bob "+ "channel: %v", err) @@ -950,9 +954,9 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, secondBobChannel, carolChannel *lnwallet.LightningChannel, startingHeight uint32, opts ...serverOption) *threeHopNetwork { - aliceDb := aliceChannel.State().Db - bobDb := firstBobChannel.State().Db - carolDb := carolChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() + bobDb := firstBobChannel.State().Db.GetParentDB() + carolDb := carolChannel.State().Db.GetParentDB() hopNetwork := newHopNetwork() @@ -1201,8 +1205,8 @@ func newTwoHopNetwork(t testing.TB, aliceChannel, bobChannel *lnwallet.LightningChannel, startingHeight uint32) *twoHopNetwork { - aliceDb := aliceChannel.State().Db - bobDb := bobChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() + bobDb := bobChannel.State().Db.GetParentDB() hopNetwork := newHopNetwork() diff --git a/lnd.go b/lnd.go index 8bf4a9fc4..23f9b835e 100644 --- a/lnd.go +++ b/lnd.go @@ -22,6 +22,7 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/wallet" @@ -697,7 +698,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error BtcdMode: cfg.BtcdMode, LtcdMode: cfg.LtcdMode, HeightHintDB: dbs.heightHintDB, - ChanStateDB: dbs.chanStateDB, + ChanStateDB: dbs.chanStateDB.ChannelStateDB(), PrivateWalletPw: privateWalletPw, PublicWalletPw: publicWalletPw, Birthday: walletInitParams.Birthday, @@ -1679,14 +1680,27 @@ func initializeDatabases(ctx context.Context, "instances") } - // Otherwise, we'll open two instances, one for the state we only need - // locally, and the other for things we want to ensure are replicated. - dbs.graphDB, err = channeldb.CreateWithBackend( - databaseBackends.GraphDB, + dbOptions := []channeldb.OptionModifier{ channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize), channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize), channeldb.OptionSetBatchCommitInterval(cfg.DB.BatchCommitInterval), channeldb.OptionDryRunMigration(cfg.DryRunMigration), + } + + // We want to pre-allocate the channel graph cache according to what we + // expect for mainnet to speed up memory allocation. + if cfg.ActiveNetParams.Name == chaincfg.MainNetParams.Name { + dbOptions = append( + dbOptions, channeldb.OptionSetPreAllocCacheNumNodes( + channeldb.DefaultPreAllocCacheNumNodes, + ), + ) + } + + // Otherwise, we'll open two instances, one for the state we only need + // locally, and the other for things we want to ensure are replicated. + dbs.graphDB, err = channeldb.CreateWithBackend( + databaseBackends.GraphDB, dbOptions..., ) switch { // Give the DB a chance to dry run the migration. Since we know that diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 4e88ae0c1..193f3a63a 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -56,7 +56,7 @@ type AddInvoiceConfig struct { // ChanDB is a global boltdb instance which is needed to access the // channel graph. - ChanDB *channeldb.DB + ChanDB *channeldb.ChannelStateDB // Graph holds a reference to the ChannelGraph database. 
Graph *channeldb.ChannelGraph diff --git a/lnrpc/invoicesrpc/config_active.go b/lnrpc/invoicesrpc/config_active.go index 01f595abc..c70f228a9 100644 --- a/lnrpc/invoicesrpc/config_active.go +++ b/lnrpc/invoicesrpc/config_active.go @@ -51,7 +51,7 @@ type Config struct { // ChanStateDB is a possibly replicated db instance which contains the // channel state - ChanStateDB *channeldb.DB + ChanStateDB *channeldb.ChannelStateDB // GenInvoiceFeatures returns a feature containing feature bits that // should be advertised on freshly generated invoices. diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 1814c6358..d39ff7a7d 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -55,7 +55,7 @@ type RouterBackend struct { FindRoute func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) MissionControl MissionControl diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 26a44cbb9..1b05d5f81 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -126,7 +126,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, findRoute := func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, _ record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) { if int64(amt) != amtSat*1000 { diff --git a/lntest/itest/lnd_channel_graph_test.go b/lntest/itest/lnd_channel_graph_test.go index c7040e02e..747f018fd 100644 --- a/lntest/itest/lnd_channel_graph_test.go +++ b/lntest/itest/lnd_channel_graph_test.go @@ -25,24 +25,20 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { ctxb := context.Background() // Create two fresh nodes and open a channel between them. - alice := net.NewNode( - t.t, "Alice", []string{ - "--minbackoff=10s", - "--chan-enable-timeout=1.5s", - "--chan-disable-timeout=3s", - "--chan-status-sample-interval=.5s", - }, - ) + alice := net.NewNode(t.t, "Alice", []string{ + "--minbackoff=10s", + "--chan-enable-timeout=1.5s", + "--chan-disable-timeout=3s", + "--chan-status-sample-interval=.5s", + }) defer shutdownAndAssert(net, t, alice) - bob := net.NewNode( - t.t, "Bob", []string{ - "--minbackoff=10s", - "--chan-enable-timeout=1.5s", - "--chan-disable-timeout=3s", - "--chan-status-sample-interval=.5s", - }, - ) + bob := net.NewNode(t.t, "Bob", []string{ + "--minbackoff=10s", + "--chan-enable-timeout=1.5s", + "--chan-disable-timeout=3s", + "--chan-status-sample-interval=.5s", + }) defer shutdownAndAssert(net, t, bob) // Connect Alice to Bob. @@ -55,36 +51,32 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { // being the sole funder of the channel. chanAmt := btcutil.Amount(100000) chanPoint := openChannelAndAssert( - t, net, alice, bob, - lntest.OpenChannelParams{ + t, net, alice, bob, lntest.OpenChannelParams{ Amt: chanAmt, }, ) // Wait for Alice and Bob to receive the channel edge from the // funding manager. 
- ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() err := alice.WaitForNetworkChannelOpen(ctxt, chanPoint) - if err != nil { - t.Fatalf("alice didn't see the alice->bob channel before "+ - "timeout: %v", err) - } + require.NoError(t.t, err, "alice didn't see the alice->bob channel") - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) err = bob.WaitForNetworkChannelOpen(ctxt, chanPoint) - if err != nil { - t.Fatalf("bob didn't see the bob->alice channel before "+ - "timeout: %v", err) - } + require.NoError(t.t, err, "bob didn't see the alice->bob channel") - // Launch a node for Carol which will connect to Alice and Bob in - // order to receive graph updates. This will ensure that the - // channel updates are propagated throughout the network. + // Launch a node for Carol which will connect to Alice and Bob in order + // to receive graph updates. This will ensure that the channel updates + // are propagated throughout the network. carol := net.NewNode(t.t, "Carol", nil) defer shutdownAndAssert(net, t, carol) + // Connect both Alice and Bob to the new node Carol, so she can sync her + // graph. net.ConnectNodes(t.t, alice, carol) net.ConnectNodes(t.t, bob, carol) + waitForGraphSync(t, carol) // assertChannelUpdate checks that the required policy update has // happened on the given node. @@ -109,12 +101,11 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { ChanPoint: chanPoint, Action: action, } - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() + _, err = node.RouterClient.UpdateChanStatus(ctxt, req) - if err != nil { - t.Fatalf("unable to call UpdateChanStatus for %s's node: %v", - node.Name(), err) - } + require.NoErrorf(t.t, err, "UpdateChanStatus") } // assertEdgeDisabled ensures that a given node has the correct @@ -122,26 +113,30 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { assertEdgeDisabled := func(node *lntest.HarnessNode, chanPoint *lnrpc.ChannelPoint, disabled bool) { - var predErr error - err = wait.Predicate(func() bool { + outPoint, err := lntest.MakeOutpoint(chanPoint) + require.NoError(t.t, err) + + err = wait.NoError(func() error { req := &lnrpc.ChannelGraphRequest{ IncludeUnannounced: true, } - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() + chanGraph, err := node.DescribeGraph(ctxt, req) if err != nil { - predErr = fmt.Errorf("unable to query node %v's graph: %v", node, err) - return false + return fmt.Errorf("unable to query node %v's "+ + "graph: %v", node, err) } numEdges := len(chanGraph.Edges) if numEdges != 1 { - predErr = fmt.Errorf("expected to find 1 edge in the graph, found %d", numEdges) - return false + return fmt.Errorf("expected to find 1 edge in "+ + "the graph, found %d", numEdges) } edge := chanGraph.Edges[0] - if edge.ChanPoint != chanPoint.GetFundingTxidStr() { - predErr = fmt.Errorf("expected chan_point %v, got %v", - chanPoint.GetFundingTxidStr(), edge.ChanPoint) + if edge.ChanPoint != outPoint.String() { + return fmt.Errorf("expected chan_point %v, "+ + "got %v", outPoint, edge.ChanPoint) } var policy *lnrpc.RoutingPolicy if node.PubKeyStr == edge.Node1Pub { @@ -150,15 +145,14 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { policy = edge.Node2Policy } if disabled != policy.Disabled { - predErr = fmt.Errorf("expected 
policy.Disabled to be %v, "+ - "but policy was %v", disabled, policy) - return false + return fmt.Errorf("expected policy.Disabled "+ + "to be %v, but policy was %v", disabled, + policy) } - return true + + return nil }, defaultTimeout) - if err != nil { - t.Fatalf("%v", predErr) - } + require.NoError(t.t, err) } // When updating the state of the channel between Alice and Bob, we @@ -193,9 +187,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { // disconnections from automatically disabling the channel again // (we don't want to clutter the network with channels that are // falsely advertised as enabled when they don't work). - if err := net.DisconnectNodes(alice, bob); err != nil { - t.Fatalf("unable to disconnect Alice from Bob: %v", err) - } + require.NoError(t.t, net.DisconnectNodes(alice, bob)) expectedPolicy.Disabled = true assertChannelUpdate(alice, expectedPolicy) assertChannelUpdate(bob, expectedPolicy) @@ -217,9 +209,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { expectedPolicy.Disabled = true assertChannelUpdate(alice, expectedPolicy) - if err := net.DisconnectNodes(alice, bob); err != nil { - t.Fatalf("unable to disconnect Alice from Bob: %v", err) - } + require.NoError(t.t, net.DisconnectNodes(alice, bob)) // Bob sends a "Disabled = true" update upon detecting the // disconnect. @@ -237,9 +227,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { // note the asymmetry between manual enable and manual disable! assertEdgeDisabled(alice, chanPoint, true) - if err := net.DisconnectNodes(alice, bob); err != nil { - t.Fatalf("unable to disconnect Alice from Bob: %v", err) - } + require.NoError(t.t, net.DisconnectNodes(alice, bob)) // Bob sends a "Disabled = true" update upon detecting the // disconnect. diff --git a/lnwallet/config.go b/lnwallet/config.go index a73120c02..cf7f3f4b8 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -18,7 +18,7 @@ type Config struct { // Database is a wrapper around a namespace within boltdb reserved for // ln-based wallet metadata. See the 'channeldb' package for further // information. 
- Database *channeldb.DB + Database *channeldb.ChannelStateDB // Notifier is used by in order to obtain notifications about funding // transaction reaching a specified confirmation depth, and to catch diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index 9c251da04..66da4f587 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -327,13 +327,13 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) { dbDir := filepath.Join(tempTestDir, "cdb") - cdb, err := channeldb.Open(dbDir) + fullDB, err := channeldb.Open(dbDir) if err != nil { return nil, err } cfg := lnwallet.Config{ - Database: cdb, + Database: fullDB.ChannelStateDB(), Notifier: notifier, SecretKeyRing: keyRing, WalletController: wc, @@ -2944,11 +2944,11 @@ func clearWalletStates(a, b *lnwallet.LightningWallet) error { a.ResetReservations() b.ResetReservations() - if err := a.Cfg.Database.Wipe(); err != nil { + if err := a.Cfg.Database.GetParentDB().Wipe(); err != nil { return err } - return b.Cfg.Database.Wipe() + return b.Cfg.Database.GetParentDB().Wipe() } func waitForMempoolTx(r *rpctest.Harness, txid *chainhash.Hash) error { diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index dcd9e202d..52bf53cfd 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -323,7 +323,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) ( RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceLocalCommit, RemoteCommitment: aliceRemoteCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: testTx, } @@ -341,7 +341,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) ( RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobLocalCommit, RemoteCommitment: bobRemoteCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 726996084..a0b6c4311 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -940,7 +940,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp RevocationStore: shachain.NewRevocationStore(), LocalCommitment: remoteCommit, RemoteCommitment: remoteCommit, - Db: dbRemote, + Db: dbRemote.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } @@ -958,7 +958,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp RevocationStore: shachain.NewRevocationStore(), LocalCommitment: localCommit, RemoteCommitment: localCommit, - Db: dbLocal, + Db: dbLocal.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } diff --git a/peer/brontide.go b/peer/brontide.go index 60c41af6f..9c6df6734 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -185,7 +185,7 @@ type Config struct { InterceptSwitch *htlcswitch.InterceptableSwitch // ChannelDB is used to fetch opened channels, and closed channels. - ChannelDB *channeldb.DB + ChannelDB *channeldb.ChannelStateDB // ChannelGraph is a pointer to the channel graph which is used to // query information about the set of known active channels. 
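The `*channeldb.ChannelStateDB` plumbing above follows the same pattern throughout the patch: the full `channeldb.DB` is opened once, and the channel-state-scoped handle returned by `ChannelStateDB()` is what gets passed into subsystem configs such as `peer.Config` and `lnwallet.Config`. A minimal Go sketch of that wiring, assuming only the fields shown; the surrounding function and package names are illustrative, not part of the patch:

```go
package example

import (
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/lnwallet"
	"github.com/lightningnetwork/lnd/peer"
)

// buildConfigs opens the full channel database once and hands the narrower
// ChannelStateDB to subsystems that only operate on channel state. All other
// config fields are omitted for brevity.
func buildConfigs(dbDir string) (*peer.Config, *lnwallet.Config, error) {
	fullDB, err := channeldb.Open(dbDir)
	if err != nil {
		return nil, nil, err
	}

	peerCfg := &peer.Config{
		// The peer only needs to read and write channel state.
		ChannelDB: fullDB.ChannelStateDB(),
	}
	walletCfg := &lnwallet.Config{
		// The wallet keeps its channel metadata in the same scoped DB.
		Database: fullDB.ChannelStateDB(),
	}

	return peerCfg, walletCfg, nil
}
```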
diff --git a/peer/test_utils.go b/peer/test_utils.go index 3ce1cbe03..ac5f5f5ab 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -229,7 +229,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } @@ -246,7 +246,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobCommit, RemoteCommitment: bobCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } @@ -321,7 +321,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanStatusSampleInterval: 30 * time.Second, ChanEnableTimeout: chanActiveTimeout, ChanDisableTimeout: 2 * time.Minute, - DB: dbAlice, + DB: dbAlice.ChannelStateDB(), Graph: dbAlice.ChannelGraph(), MessageSigner: nodeSignerAlice, OurPubKey: aliceKeyPub, @@ -359,7 +359,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanActiveTimeout: chanActiveTimeout, InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), - ChannelDB: dbAlice, + ChannelDB: dbAlice.ChannelStateDB(), FeeEstimator: estimator, Wallet: wallet, ChainNotifier: notifier, diff --git a/routing/graph.go b/routing/graph.go index 83f06807e..7e0ba65b2 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -2,7 +2,6 @@ package routing import ( "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -10,10 +9,10 @@ import ( // routingGraph is an abstract interface that provides information about nodes // and edges to pathfinding. type routingGraph interface { - // forEachNodeChannel calls the callback for every channel of the given node. + // forEachNodeChannel calls the callback for every channel of the given + // node. forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error + cb func(channel *channeldb.DirectedChannel) error) error // sourceNode returns the source node of the graph. sourceNode() route.Vertex @@ -22,59 +21,44 @@ type routingGraph interface { fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// dbRoutingTx is a routingGraph implementation that retrieves from the +// CachedGraph is a routingGraph implementation that retrieves from the // database. -type dbRoutingTx struct { +type CachedGraph struct { graph *channeldb.ChannelGraph - tx kvdb.RTx source route.Vertex } -// newDbRoutingTx instantiates a new db-connected routing graph. It implictly +// A compile-time assertion to make sure CachedGraph implements the routingGraph +// interface. +var _ routingGraph = (*CachedGraph)(nil) + +// NewCachedGraph instantiates a new db-connected routing graph. It implicitly // instantiates a new read transaction. -func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { +func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) { sourceNode, err := graph.SourceNode() if err != nil { return nil, err } - tx, err := graph.Database().BeginReadTx() - if err != nil { - return nil, err - } - - return &dbRoutingTx{ + return &CachedGraph{ graph: graph, - tx: tx, source: sourceNode.PubKeyBytes, }, nil } -// close closes the underlying db transaction.
-func (g *dbRoutingTx) close() error { - return g.tx.Rollback() -} - // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error { +func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex, + cb func(channel *channeldb.DirectedChannel) error) error { - txCb := func(_ kvdb.RTx, info *channeldb.ChannelEdgeInfo, - p1, p2 *channeldb.ChannelEdgePolicy) error { - - return cb(info, p1, p2) - } - - return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb) + return g.graph.ForEachNodeChannel(nodePub, cb) } // sourceNode returns the source node of the graph. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) sourceNode() route.Vertex { +func (g *CachedGraph) sourceNode() route.Vertex { return g.source } @@ -82,23 +66,8 @@ func (g *dbRoutingTx) sourceNode() route.Vertex { // unknown, assume no additional features are supported. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( +func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { - targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub) - switch err { - - // If the node exists and has features, return them directly. - case nil: - return targetNode.Features, nil - - // If we couldn't find a node announcement, populate a blank feature - // vector. - case channeldb.ErrGraphNodeNotFound: - return lnwire.EmptyFeatureVector(), nil - - // Otherwise bubble the error up. - default: - return nil, err - } + return g.graph.FetchNodeFeatures(nodePub) } diff --git a/routing/heap.go b/routing/heap.go index f6869663c..36563bb66 100644 --- a/routing/heap.go +++ b/routing/heap.go @@ -39,7 +39,7 @@ type nodeWithDist struct { weight int64 // nextHop is the edge this route comes from. - nextHop *channeldb.ChannelEdgePolicy + nextHop *channeldb.CachedEdgePolicy // routingInfoSize is the total size requirement for the payloads field // in the onion packet from this hop towards the final destination. diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 114b2272e..d13b1c432 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, } session, err := newPaymentSession( - &payment, getBandwidthHints, - func() (routingGraph, func(), error) { - return c.graph, func() {}, nil - }, - mc, c.pathFindingCfg, + &payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg, ) if err != nil { c.t.Fatal(err) diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 3834d9e51..6d0156666 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -159,8 +159,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // // NOTE: Part of the routingGraph interface. func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error { + cb func(channel *channeldb.DirectedChannel) error) error { // Look up the mock node. node, ok := m.nodes[nodePub] @@ -171,36 +170,31 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // Iterate over all of its channels. 
for peer, channel := range node.channels { // Lexicographically sort the pubkeys. - var node1, node2 route.Vertex + var node1 route.Vertex if bytes.Compare(nodePub[:], peer[:]) == -1 { - node1, node2 = peer, nodePub + node1 = peer } else { - node1, node2 = nodePub, peer + node1 = nodePub } peerNode := m.nodes[peer] // Call the per channel callback. err := cb( - &channeldb.ChannelEdgeInfo{ - NodeKey1Bytes: node1, - NodeKey2Bytes: node2, - }, - &channeldb.ChannelEdgePolicy{ - ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: peer, - Features: lnwire.EmptyFeatureVector(), + &channeldb.DirectedChannel{ + ChannelID: channel.id, + IsNode1: nodePub == node1, + OtherNode: peer, + Capacity: channel.capacity, + OutPolicySet: true, + InPolicy: &channeldb.CachedEdgePolicy{ + ChannelID: channel.id, + ToNodePubKey: func() route.Vertex { + return nodePub + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), + FeeBaseMSat: peerNode.baseFee, }, - FeeBaseMSat: node.baseFee, - }, - &channeldb.ChannelEdgePolicy{ - ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: nodePub, - Features: lnwire.EmptyFeatureVector(), - }, - FeeBaseMSat: peerNode.baseFee, }, ) if err != nil { diff --git a/routing/mock_test.go b/routing/mock_test.go index 383f89185..a59ae2aa4 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -173,13 +173,13 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, } func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, - _ *btcec.PublicKey, _ *channeldb.ChannelEdgePolicy) bool { + _ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool { return false } func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey, - _ uint64) *channeldb.ChannelEdgePolicy { + _ uint64) *channeldb.CachedEdgePolicy { return nil } @@ -637,17 +637,17 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) return args.Bool(0) } func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy { + channelID uint64) *channeldb.CachedEdgePolicy { args := m.Called(pubKey, channelID) - return args.Get(0).(*channeldb.ChannelEdgePolicy) + return args.Get(0).(*channeldb.CachedEdgePolicy) } type mockControlTower struct { diff --git a/routing/pathfind.go b/routing/pathfind.go index 3d722c822..27a67ea7a 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -42,7 +42,7 @@ const ( type pathFinder = func(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( - []*channeldb.ChannelEdgePolicy, error) + []*channeldb.CachedEdgePolicy, error) var ( // DefaultAttemptCost is the default fixed virtual cost in path finding @@ -76,7 +76,7 @@ var ( // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex - edge *channeldb.ChannelEdgePolicy + edge *channeldb.CachedEdgePolicy } // finalHopParams encapsulates various parameters for route construction that @@ -102,7 +102,7 @@ type finalHopParams struct { // any feature vectors on all hops have been validated for transitive // dependencies. 
func newRoute(sourceVertex route.Vertex, - pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32, + pathEdges []*channeldb.CachedEdgePolicy, currentHeight uint32, finalHop finalHopParams) (*route.Route, error) { var ( @@ -147,10 +147,10 @@ func newRoute(sourceVertex route.Vertex, supports := func(feature lnwire.FeatureBit) bool { // If this edge comes from router hints, the features // could be nil. - if edge.Node.Features == nil { + if edge.ToNodeFeatures == nil { return false } - return edge.Node.Features.HasFeature(feature) + return edge.ToNodeFeatures.HasFeature(feature) } // We start by assuming the node doesn't support TLV. We'll now @@ -225,7 +225,7 @@ func newRoute(sourceVertex route.Vertex, // each new hop such that, the final slice of hops will be in // the forwards order. currentHop := &route.Hop{ - PubKeyBytes: edge.Node.PubKeyBytes, + PubKeyBytes: edge.ToNodePubKey(), ChannelID: edge.ChannelID, AmtToForward: amtToForward, OutgoingTimeLock: outgoingTimeLock, @@ -280,7 +280,7 @@ type graphParams struct { // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the // channel graph. - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy // bandwidthHints is an optional map from channels to bandwidths that // can be populated if the caller has a better estimate of the current @@ -359,14 +359,12 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi - cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge, - _ *channeldb.ChannelEdgePolicy) error { - - if outEdge == nil { + cb := func(channel *channeldb.DirectedChannel) error { + if !channel.OutPolicySet { return nil } - chanID := outEdge.ChannelID + chanID := channel.ChannelID // Enforce outgoing channel restriction. if outgoingChans != nil { @@ -381,9 +379,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // This can happen when a channel is added to the graph after // we've already queried the bandwidth hints. if !ok { - bandwidth = lnwire.NewMSatFromSatoshis( - edgeInfo.Capacity, - ) + bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity) } if bandwidth > max { @@ -416,7 +412,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -523,7 +519,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. for _, outgoingEdgePolicy := range outgoingEdgePolicies { - toVertex := outgoingEdgePolicy.Node.PubKeyBytes + toVertex := outgoingEdgePolicy.ToNodePubKey() incomingEdgePolicy := &edgePolicyWithSource{ sourceNode: vertex, edge: outgoingEdgePolicy, @@ -587,7 +583,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // satisfy our specific requirements. 
processEdge := func(fromVertex route.Vertex, fromFeatures *lnwire.FeatureVector, - edge *channeldb.ChannelEdgePolicy, toNodeDist *nodeWithDist) { + edge *channeldb.CachedEdgePolicy, toNodeDist *nodeWithDist) { edgesExpanded++ @@ -883,13 +879,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Use the distance map to unravel the forward path from source to // target. - var pathEdges []*channeldb.ChannelEdgePolicy + var pathEdges []*channeldb.CachedEdgePolicy currentNode := source for { // Determine the next hop forward using the next map. currentNodeWithDist, ok := distance[currentNode] if !ok { - // If the node doesnt have a next hop it means we didn't find a path. + // If the node doesn't have a next hop it means we + // didn't find a path. return nil, errNoPathFound } @@ -897,7 +894,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, pathEdges = append(pathEdges, currentNodeWithDist.nextHop) // Advance current node. - currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes + currentNode = currentNodeWithDist.nextHop.ToNodePubKey() // Check stop condition at the end of this loop. This prevents // breaking out too soon for self-payments that have target set @@ -918,7 +915,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // route construction does not care where the features are actually // taken from. In the future we may wish to do route construction within // findPath, and avoid using ChannelEdgePolicy altogether. - pathEdges[len(pathEdges)-1].Node.Features = features + pathEdges[len(pathEdges)-1].ToNodeFeatures = features log.Debugf("Found route: probability=%v, hops=%v, fee=%v", distance[source].probability, len(pathEdges), diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 67fecb321..426faa099 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -23,6 +23,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" @@ -148,26 +149,36 @@ type testChan struct { // makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing // purposes. A callback which cleans up the created temporary directories is // also returned and intended to be executed after the test completes. -func makeTestGraph() (*channeldb.ChannelGraph, func(), error) { +func makeTestGraph() (*channeldb.ChannelGraph, kvdb.Backend, func(), error) { // First, create a temporary directory to be used for the duration of // this test. tempDirName, err := ioutil.TempDir("", "channeldb") if err != nil { - return nil, nil, err + return nil, nil, nil, err } - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) + // Next, create channelgraph for the first time. 
+ backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr") if err != nil { - return nil, nil, err + return nil, nil, nil, err } cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) + backendCleanup() + _ = os.RemoveAll(tempDirName) } - return cdb.ChannelGraph(), cleanUp, nil + opts := channeldb.DefaultOptions() + graph, err := channeldb.NewChannelGraph( + backend, opts.RejectCacheSize, opts.ChannelCacheSize, + opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, + ) + if err != nil { + cleanUp() + return nil, nil, nil, err + } + + return graph, backend, cleanUp, nil } // parseTestGraph returns a fully populated ChannelGraph given a path to a JSON @@ -197,7 +208,7 @@ func parseTestGraph(path string) (*testGraphInstance, error) { testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. - graph, cleanUp, err := makeTestGraph() + graph, graphBackend, cleanUp, err := makeTestGraph() if err != nil { return nil, err } @@ -293,6 +304,16 @@ func parseTestGraph(path string) (*testGraphInstance, error) { } } + aliasForNode := func(node route.Vertex) string { + for alias, pubKey := range aliasMap { + if pubKey == node { + return alias + } + } + + return "" + } + // With all the vertexes inserted, we can now insert the edges into the // test graph. for _, edge := range g.Edges { @@ -342,10 +363,17 @@ func parseTestGraph(path string) (*testGraphInstance, error) { return nil, err } + channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags) + isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0 + targetNode := edgeInfo.NodeKey1Bytes + if isUpdate1 { + targetNode = edgeInfo.NodeKey2Bytes + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), - ChannelFlags: lnwire.ChanUpdateChanFlags(edge.ChannelFlags), + ChannelFlags: channelFlags, ChannelID: edge.ChannelID, LastUpdate: testTime, TimeLockDelta: edge.Expiry, @@ -353,6 +381,10 @@ func parseTestGraph(path string) (*testGraphInstance, error) { MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), + Node: &channeldb.LightningNode{ + Alias: aliasForNode(targetNode), + PubKeyBytes: targetNode, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -381,11 +413,12 @@ func parseTestGraph(path string) (*testGraphInstance, error) { } return &testGraphInstance{ - graph: graph, - cleanUp: cleanUp, - aliasMap: aliasMap, - privKeyMap: privKeyMap, - channelIDs: channelIDs, + graph: graph, + graphBackend: graphBackend, + cleanUp: cleanUp, + aliasMap: aliasMap, + privKeyMap: privKeyMap, + channelIDs: channelIDs, }, nil } @@ -447,8 +480,9 @@ type testChannel struct { } type testGraphInstance struct { - graph *channeldb.ChannelGraph - cleanUp func() + graph *channeldb.ChannelGraph + graphBackend kvdb.Backend + cleanUp func() // aliasMap is a map from a node's alias to its public key. This type is // provided in order to allow easily look up from the human memorable alias @@ -482,7 +516,7 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. 
- graph, cleanUp, err := makeTestGraph() + graph, graphBackend, cleanUp, err := makeTestGraph() if err != nil { return nil, err } @@ -622,6 +656,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( channelFlags |= lnwire.ChanUpdateDisabled } + node2Features := lnwire.EmptyFeatureVector() + if node2.testChannelPolicy != nil { + node2Features = node2.Features + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, @@ -633,6 +672,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( MaxHTLC: node1.MaxHTLC, FeeBaseMSat: node1.FeeBaseMsat, FeeProportionalMillionths: node1.FeeRate, + Node: &channeldb.LightningNode{ + Alias: node2.Alias, + PubKeyBytes: node2Vertex, + Features: node2Features, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -650,6 +694,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( } channelFlags |= lnwire.ChanUpdateDirection + node1Features := lnwire.EmptyFeatureVector() + if node1.testChannelPolicy != nil { + node1Features = node1.Features + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, @@ -661,6 +710,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( MaxHTLC: node2.MaxHTLC, FeeBaseMSat: node2.FeeBaseMsat, FeeProportionalMillionths: node2.FeeRate, + Node: &channeldb.LightningNode{ + Alias: node1.Alias, + PubKeyBytes: node1Vertex, + Features: node1Features, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -671,10 +725,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( } return &testGraphInstance{ - graph: graph, - cleanUp: cleanUp, - aliasMap: aliasMap, - privKeyMap: privKeyMap, + graph: graph, + graphBackend: graphBackend, + cleanUp: cleanUp, + aliasMap: aliasMap, + privKeyMap: privKeyMap, }, nil } @@ -1044,20 +1099,23 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) { // Create the channel edge going from songoku to doge and include it in // our map of additional edges. 
- songokuToDoge := &channeldb.ChannelEdgePolicy{ - Node: doge, + songokuToDoge := &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return doge.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), ChannelID: 1337, FeeBaseMSat: 1, FeeProportionalMillionths: 1000, TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*channeldb.ChannelEdgePolicy{ + additionalEdges := map[route.Vertex][]*channeldb.CachedEdgePolicy{ graph.aliasMap["songoku"]: {songokuToDoge}, } find := func(r *RestrictParams) ( - []*channeldb.ChannelEdgePolicy, error) { + []*channeldb.CachedEdgePolicy, error) { return dbFindPath( graph.graph, additionalEdges, nil, @@ -1124,14 +1182,13 @@ func TestNewRoute(t *testing.T) { createHop := func(baseFee lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, bandwidth lnwire.MilliSatoshi, - timeLockDelta uint16) *channeldb.ChannelEdgePolicy { + timeLockDelta uint16) *channeldb.CachedEdgePolicy { - return &channeldb.ChannelEdgePolicy{ - Node: &channeldb.LightningNode{ - Features: lnwire.NewFeatureVector( - nil, nil, - ), + return &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return route.Vertex{} }, + ToNodeFeatures: lnwire.NewFeatureVector(nil, nil), FeeProportionalMillionths: feeRate, FeeBaseMSat: baseFee, TimeLockDelta: timeLockDelta, @@ -1144,7 +1201,7 @@ func TestNewRoute(t *testing.T) { // hops is the list of hops (the route) that gets passed into // the call to newRoute. - hops []*channeldb.ChannelEdgePolicy + hops []*channeldb.CachedEdgePolicy // paymentAmount is the amount that is send into the route // indicated by hops. @@ -1193,7 +1250,7 @@ func TestNewRoute(t *testing.T) { // For a single hop payment, no fees are expected to be paid. name: "single hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(100, 1000, 1000000, 10), }, expectedFees: []lnwire.MilliSatoshi{0}, @@ -1206,7 +1263,7 @@ func TestNewRoute(t *testing.T) { // a fee to receive the payment. name: "two hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1221,7 +1278,7 @@ func TestNewRoute(t *testing.T) { name: "two hop tlv onion feature", destFeatures: tlvFeatures, paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1238,7 +1295,7 @@ func TestNewRoute(t *testing.T) { destFeatures: tlvPayAddrFeatures, paymentAddr: &testPaymentAddr, paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1258,7 +1315,7 @@ func TestNewRoute(t *testing.T) { // gets rounded down to 1. name: "three hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10, 1000000, 10), createHop(0, 10, 1000000, 5), createHop(0, 10, 1000000, 3), @@ -1273,7 +1330,7 @@ func TestNewRoute(t *testing.T) { // because of the increase amount to forward. name: "three hop with fee carry over", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10000, 1000000, 10), createHop(0, 10000, 1000000, 5), createHop(0, 10000, 1000000, 3), @@ -1288,7 +1345,7 @@ func TestNewRoute(t *testing.T) { // effect. 
name: "three hop with minimal fees for carry over", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10000, 1000000, 10), // First hop charges 0.1% so the second hop fee @@ -1312,7 +1369,7 @@ func TestNewRoute(t *testing.T) { // custom feature vector. if testCase.destFeatures != nil { finalHop := testCase.hops[len(testCase.hops)-1] - finalHop.Node.Features = testCase.destFeatures + finalHop.ToNodeFeatures = testCase.destFeatures } assertRoute := func(t *testing.T, route *route.Route) { @@ -1539,7 +1596,7 @@ func TestDestTLVGraphFallback(t *testing.T) { } find := func(r *RestrictParams, - target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) { + target route.Vertex) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( ctx.graph, nil, nil, @@ -2120,7 +2177,7 @@ func TestPathFindSpecExample(t *testing.T) { // Carol, so we set "B" as the source node so path finding starts from // Bob. bob := ctx.aliases["B"] - bobNode, err := ctx.graph.FetchLightningNode(nil, bob) + bobNode, err := ctx.graph.FetchLightningNode(bob) if err != nil { t.Fatalf("unable to find bob: %v", err) } @@ -2170,7 +2227,7 @@ func TestPathFindSpecExample(t *testing.T) { // Next, we'll set A as the source node so we can assert that we create // the proper route for any queries starting with Alice. alice := ctx.aliases["A"] - aliceNode, err := ctx.graph.FetchLightningNode(nil, alice) + aliceNode, err := ctx.graph.FetchLightningNode(alice) if err != nil { t.Fatalf("unable to find alice: %v", err) } @@ -2270,16 +2327,16 @@ func TestPathFindSpecExample(t *testing.T) { } func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, - path []*channeldb.ChannelEdgePolicy, nodeAliases ...string) { + path []*channeldb.CachedEdgePolicy, nodeAliases ...string) { if len(path) != len(nodeAliases) { t.Fatal("number of hops and number of aliases do not match") } for i, hop := range path { - if hop.Node.PubKeyBytes != aliasMap[nodeAliases[i]] { + if hop.ToNodePubKey() != aliasMap[nodeAliases[i]] { t.Fatalf("expected %v to be pos #%v in hop, instead "+ - "%v was", nodeAliases[i], i, hop.Node.Alias) + "%v was", nodeAliases[i], i, hop.ToNodePubKey()) } } } @@ -2930,7 +2987,7 @@ func (c *pathFindingTestContext) cleanup() { } func (c *pathFindingTestContext) findPath(target route.Vertex, - amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy, + amt lnwire.MilliSatoshi) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( @@ -2939,7 +2996,9 @@ func (c *pathFindingTestContext) findPath(target route.Vertex, ) } -func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) { +func (c *pathFindingTestContext) assertPath(path []*channeldb.CachedEdgePolicy, + expected []uint64) { + if len(path) != len(expected) { c.t.Fatalf("expected path of length %v, but got %v", len(expected), len(path)) @@ -2956,28 +3015,22 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, // dbFindPath calls findPath after getting a db transaction from the database // graph. 
func dbFindPath(graph *channeldb.ChannelGraph, - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy, + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy, bandwidthHints map[uint64]lnwire.MilliSatoshi, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { - routingTx, err := newDbRoutingTx(graph) + routingGraph, err := NewCachedGraph(graph) if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() return findPath( &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, - graph: routingTx, + graph: routingGraph, }, r, cfg, source, target, amt, finalHtlcExpiry, ) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index aa856e7b2..945a53466 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -898,7 +898,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, var ( isAdditionalEdge bool - policy *channeldb.ChannelEdgePolicy + policy *channeldb.CachedEdgePolicy ) // Before we apply the channel update, we need to decide whether the diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 6a3517023..d233d8bde 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -472,8 +472,8 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, Payer: payer, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, - QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - return lnwire.NewMSatFromSatoshis(e.Capacity) + QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + return lnwire.NewMSatFromSatoshis(c.Capacity) }, NextPaymentID: func() (uint64, error) { next := atomic.AddUint64(&uniquePaymentID, 1) diff --git a/routing/payment_session.go b/routing/payment_session.go index 22e88090b..bbf9b6f96 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -144,13 +144,13 @@ type PaymentSession interface { // a boolean to indicate whether the update has been applied without // error. UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, - policy *channeldb.ChannelEdgePolicy) bool + policy *channeldb.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query // the ephemeral channel edge policy for additional edges. Returns a nil // if nothing found. GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy + channelID uint64) *channeldb.CachedEdgePolicy } // paymentSession is used during an HTLC routings session to prune the local @@ -162,7 +162,7 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. 
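In dbFindPath above, the per-query database transaction (newDbRoutingTx plus the deferred close) is replaced by a single CachedGraph that satisfies the same routingGraph interface. The interface itself is not reproduced in this diff; judging from the call sites visible here (sourceNode, forEachNodeChannel and the graph field of graphParams), it plausibly looks along these lines:

// routingGraph is sketched from its call sites in this diff only; the real
// definition in the routing package may contain further methods, e.g. for
// fetching node features.
type routingGraph interface {
	// sourceNode returns the vertex of our own node.
	sourceNode() route.Vertex

	// forEachNodeChannel walks every channel of the given node, each
	// presented from that node's point of view.
	forEachNodeChannel(node route.Vertex,
		cb func(channel *channeldb.DirectedChannel) error) error
}

The paymentSession struct in the next hunk holds one of these for the lifetime of the session instead of a getRoutingGraph constructor closure.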
type paymentSession struct { - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error) @@ -172,7 +172,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (routingGraph, func(), error) + routingGraph routingGraph // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probabiity. @@ -193,7 +193,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error), - getRoutingGraph func() (routingGraph, func(), error), + routingGraph routingGraph, missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { @@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment, getBandwidthHints: getBandwidthHints, payment: p, pathFinder: findPath, - getRoutingGraph: getRoutingGraph, + routingGraph: routingGraph, pathFindingConfig: pathFindingConfig, missionControl: missionControl, minShardAmt: DefaultShardMinAmt, @@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, p.log.Debugf("pathfinding for amt=%v", maxAmt) - // Get a routing graph. - routingGraph, cleanup, err := p.getRoutingGraph() - if err != nil { - return nil, err - } - - sourceVertex := routingGraph.sourceNode() + sourceVertex := p.routingGraph.sourceNode() // Find a route for the current amount. path, err := p.pathFinder( &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, - graph: routingGraph, + graph: p.routingGraph, }, restrictions, &p.pathFindingConfig, sourceVertex, p.payment.Target, maxAmt, finalHtlcExpiry, ) - // Close routing graph. - cleanup() - switch { case err == errNoPathFound: // Don't split if this is a legacy payment without mpp @@ -403,7 +394,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { // Validate the message signature. if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil { @@ -428,7 +419,7 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, // ephemeral channel edge policy for additional edges. Returns a nil if nothing // found. func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy { + channelID uint64) *channeldb.CachedEdgePolicy { target := route.NewVertex(pubKey) diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 8122ff711..d688f9814 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -17,14 +17,14 @@ var _ PaymentSessionSource = (*SessionSource)(nil) type SessionSource struct { // Graph is the channel graph that will be used to gather metrics from // and also to carry out path finding queries. - Graph *channeldb.ChannelGraph + Graph routingGraph // QueryBandwidth is a method that allows querying the lower link layer // to determine the up to date available bandwidth at a prospective link // to be traversed. 
If the link isn't available, then a value of zero // should be returned. Otherwise, the current up to date knowledge of // the available bandwidth of the link should be returned. - QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi // MissionControl is a shared memory of sorts that executions of payment // path finding use in order to remember which vertexes/edges were @@ -40,21 +40,6 @@ type SessionSource struct { PathFindingConfig PathFindingConfig } -// getRoutingGraph returns a routing graph and a clean-up function for -// pathfinding. -func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { - routingTx, err := newDbRoutingTx(m.Graph) - if err != nil { - return nil, nil, err - } - return routingTx, func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }, nil -} - // NewPaymentSession creates a new payment session backed by the latest prune // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the @@ -62,19 +47,16 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - sourceNode, err := m.Graph.SourceNode() - if err != nil { - return nil, err - } - getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { - return generateBandwidthHints(sourceNode, m.QueryBandwidth) + return generateBandwidthHints( + m.Graph.sourceNode(), m.Graph, m.QueryBandwidth, + ) } session, err := newPaymentSession( - p, getBandwidthHints, m.getRoutingGraph, + p, getBandwidthHints, m.Graph, m.MissionControl, m.PathFindingConfig, ) if err != nil { @@ -96,9 +78,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession { // RouteHintsToEdges converts a list of invoice route hints to an edge map that // can be passed into pathfinding. func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( - map[route.Vertex][]*channeldb.ChannelEdgePolicy, error) { + map[route.Vertex][]*channeldb.CachedEdgePolicy, error) { - edges := make(map[route.Vertex][]*channeldb.ChannelEdgePolicy) + edges := make(map[route.Vertex][]*channeldb.CachedEdgePolicy) // Traverse through all of the available hop hints and include them in // our edges map, indexed by the public key of the channel's starting @@ -128,9 +110,12 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( // Finally, create the channel edge from the hop hint // and add it to list of edges corresponding to the node // at the start of the channel. 
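RouteHintsToEdges is typically fed the hints carried in a decoded invoice, and its result is exactly what FindRoute and the payment session accept as additional edges. A minimal usage sketch (the helper name hintedEdges and its parameters are hypothetical; zpay32, channeldb and route are the standard lnd packages):

package example

import (
	"github.com/btcsuite/btcd/chaincfg"
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/routing"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/zpay32"
)

// hintedEdges decodes a payment request and converts its private route
// hints into the ephemeral edge map used by pathfinding.
func hintedEdges(invoiceStr string, netParams *chaincfg.Params) (
	map[route.Vertex][]*channeldb.CachedEdgePolicy, error) {

	payReq, err := zpay32.Decode(invoiceStr, netParams)
	if err != nil {
		return nil, err
	}

	// Edges are keyed by the node at the start of each hinted channel,
	// so they can be unioned with the public graph during the search.
	return routing.RouteHintsToEdges(
		payReq.RouteHints, route.NewVertex(payReq.Destination),
	)
}

The hunk directly below is that per-hint edge construction, now emitting a CachedEdgePolicy whose to-node is exposed through a closure rather than a node pointer.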
- edge := &channeldb.ChannelEdgePolicy{ - Node: endNode, - ChannelID: hopHint.ChannelID, + edge := &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return endNode.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), + ChannelID: hopHint.ChannelID, FeeBaseMSat: lnwire.MilliSatoshi( hopHint.FeeBaseMSat, ), diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index edc4515b5..dae331f84 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { return nil, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &sessionGraph{}, &MissionControl{}, PathFindingConfig{}, ) @@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) { return nil, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &sessionGraph{}, &MissionControl{}, PathFindingConfig{}, ) @@ -217,7 +213,7 @@ func TestRequestRoute(t *testing.T) { session.pathFinder = func( g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { // We expect find path to receive a cltv limit excluding the // final cltv delta (including the block padding). @@ -225,13 +221,14 @@ func TestRequestRoute(t *testing.T) { t.Fatal("wrong cltv limit") } - path := []*channeldb.ChannelEdgePolicy{ + path := []*channeldb.CachedEdgePolicy{ { - Node: &channeldb.LightningNode{ - Features: lnwire.NewFeatureVector( - nil, nil, - ), + ToNodePubKey: func() route.Vertex { + return route.Vertex{} }, + ToNodeFeatures: lnwire.NewFeatureVector( + nil, nil, + ), }, } diff --git a/routing/router.go b/routing/router.go index 6ebf86c19..dd8a375a2 100644 --- a/routing/router.go +++ b/routing/router.go @@ -339,7 +339,7 @@ type Config struct { // a value of zero should be returned. Otherwise, the current up to // date knowledge of the available bandwidth of the link should be // returned. - QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi // NextPaymentID is a method that guarantees to return a new, unique ID // each time it is called. This is used by the router to generate a @@ -406,6 +406,10 @@ type ChannelRouter struct { // when doing any path finding. selfNode *channeldb.LightningNode + // cachedGraph is an instance of routingGraph that caches the source node as + // well as the channel graph itself in memory. + cachedGraph routingGraph + // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to // UpdateFilter. @@ -460,14 +464,17 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil) // channel graph is a subset of the UTXO set) set, then the router will proceed // to fully sync to the latest state of the UTXO set. 
func New(cfg Config) (*ChannelRouter, error) { - selfNode, err := cfg.Graph.SourceNode() if err != nil { return nil, err } r := &ChannelRouter{ - cfg: &cfg, + cfg: &cfg, + cachedGraph: &CachedGraph{ + graph: cfg.Graph, + source: selfNode.PubKeyBytes, + }, networkUpdates: make(chan *routingMsg), topologyClients: make(map[uint64]*topologyClient), ntfnClientUpdates: make(chan *topologyClientUpdate), @@ -1727,7 +1734,7 @@ type routingMsg struct { func (r *ChannelRouter) FindRoute(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *RestrictParams, destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) { log.Debugf("Searching for path to %v, sending %v", target, amt) @@ -1735,7 +1742,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := generateBandwidthHints( - r.selfNode, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -1752,22 +1759,11 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // execute our path finding algorithm. finalHtlcExpiry := currentHeight + int32(finalExpiry) - routingTx, err := newDbRoutingTx(r.cfg.Graph) - if err != nil { - return nil, err - } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() - path, err := findPath( &graphParams{ additionalEdges: routeHints, bandwidthHints: bandwidthHints, - graph: routingTx, + graph: r.cachedGraph, }, restrictions, &r.cfg.PathFindingConfig, @@ -2505,8 +2501,10 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( // within the graph. // // NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) FetchLightningNode(node route.Vertex) (*channeldb.LightningNode, error) { - return r.cfg.Graph.FetchLightningNode(nil, node) +func (r *ChannelRouter) FetchLightningNode( + node route.Vertex) (*channeldb.LightningNode, error) { + + return r.cfg.Graph.FetchLightningNode(node) } // ForEachNode is used to iterate over every node in router topology. @@ -2661,19 +2659,19 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { // these hints allows us to reduce the number of extraneous attempts as we can // skip channels that are inactive, or just don't have enough bandwidth to // carry the payment. -func generateBandwidthHints(sourceNode *channeldb.LightningNode, - queryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) (map[uint64]lnwire.MilliSatoshi, error) { +func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph, + queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( + map[uint64]lnwire.MilliSatoshi, error) { // First, we'll collect the set of outbound edges from the target // source node. 
- var localChans []*channeldb.ChannelEdgeInfo - err := sourceNode.ForEachChannel(nil, func(tx kvdb.RTx, - edgeInfo *channeldb.ChannelEdgeInfo, - _, _ *channeldb.ChannelEdgePolicy) error { - - localChans = append(localChans, edgeInfo) - return nil - }) + var localChans []*channeldb.DirectedChannel + err := graph.forEachNodeChannel( + sourceNode, func(channel *channeldb.DirectedChannel) error { + localChans = append(localChans, channel) + return nil + }, + ) if err != nil { return nil, err } @@ -2726,7 +2724,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := generateBandwidthHints( - r.selfNode, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -2756,18 +2754,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, runningAmt = *amt } - // Open a transaction to execute the graph queries in. - routingTx, err := newDbRoutingTx(r.cfg.Graph) - if err != nil { - return nil, err - } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() - // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes for i := len(hops) - 1; i >= 0; i-- { @@ -2786,7 +2772,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // known in the graph. u := newUnifiedPolicies(source, toNode, outgoingChans) - err := u.addGraphPolicies(routingTx) + err := u.addGraphPolicies(r.cachedGraph) if err != nil { return nil, err } @@ -2832,7 +2818,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // total amount, we make a forward pass. Because the amount may have // been increased in the backward pass, fees need to be recalculated and // amount ranges re-checked. 
- var pathEdges []*channeldb.ChannelEdgePolicy + var pathEdges []*channeldb.CachedEdgePolicy receiverAmt := runningAmt for i, edge := range edges { policy := edge.getPolicy(receiverAmt, bandwidthHints) diff --git a/routing/router_test.go b/routing/router_test.go index 2633bd5ab..4b5dd505f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -125,17 +125,19 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, } mc, err := NewMissionControl( - graphInstance.graph.Database(), route.Vertex{}, - mcConfig, + graphInstance.graphBackend, route.Vertex{}, mcConfig, ) require.NoError(t, err, "failed to create missioncontrol") - sessionSource := &SessionSource{ - Graph: graphInstance.graph, - QueryBandwidth: func( - e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + cachedGraph, err := NewCachedGraph(graphInstance.graph) + require.NoError(t, err) - return lnwire.NewMSatFromSatoshis(e.Capacity) + sessionSource := &SessionSource{ + Graph: cachedGraph, + QueryBandwidth: func( + c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + + return lnwire.NewMSatFromSatoshis(c.Capacity) }, PathFindingConfig: pathFindingConfig, MissionControl: mc, @@ -159,7 +161,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, QueryBandwidth: func( - e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + e *channeldb.DirectedChannel) lnwire.MilliSatoshi { return lnwire.NewMSatFromSatoshis(e.Capacity) }, @@ -188,7 +190,6 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, cleanUp := func() { ctx.router.Stop() - graphInstance.cleanUp() } return ctx, cleanUp @@ -197,17 +198,10 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, func createTestCtxSingleNode(t *testing.T, startingHeight uint32) (*testCtx, func()) { - var ( - graph *channeldb.ChannelGraph - sourceNode *channeldb.LightningNode - cleanup func() - err error - ) - - graph, cleanup, err = makeTestGraph() + graph, graphBackend, cleanup, err := makeTestGraph() require.NoError(t, err, "failed to make test graph") - sourceNode, err = createTestNode() + sourceNode, err := createTestNode() require.NoError(t, err, "failed to create test node") require.NoError(t, @@ -215,8 +209,9 @@ func createTestCtxSingleNode(t *testing.T, ) graphInstance := &testGraphInstance{ - graph: graph, - cleanUp: cleanup, + graph: graph, + graphBackend: graphBackend, + cleanUp: cleanup, } return createTestCtxFromGraphInstance( @@ -1401,6 +1396,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + }, } edgePolicy.ChannelFlags = 0 @@ -1417,6 +1415,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + }, } edgePolicy.ChannelFlags = 1 @@ -1498,6 +1499,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + }, } edgePolicy.ChannelFlags = 0 @@ -1513,6 +1517,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + }, } edgePolicy.ChannelFlags = 1 @@ -1577,7 +1584,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to find any routes: %v", 
err) } - copy1, err := ctx.graph.FetchLightningNode(nil, pub1) + copy1, err := ctx.graph.FetchLightningNode(pub1) if err != nil { t.Fatalf("unable to fetch node: %v", err) } @@ -1586,7 +1593,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("fetched node not equal to original") } - copy2, err := ctx.graph.FetchLightningNode(nil, pub2) + copy2, err := ctx.graph.FetchLightningNode(pub2) if err != nil { t.Fatalf("unable to fetch node: %v", err) } @@ -2474,8 +2481,8 @@ func TestFindPathFeeWeighting(t *testing.T) { if len(path) != 1 { t.Fatalf("expected path length of 1, instead was: %v", len(path)) } - if path[0].Node.Alias != "luoji" { - t.Fatalf("wrong node: %v", path[0].Node.Alias) + if path[0].ToNodePubKey() != ctx.aliases["luoji"] { + t.Fatalf("wrong node: %v", path[0].ToNodePubKey()) } } diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 0ff509382..fe7cc1ec4 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -40,7 +40,7 @@ func newUnifiedPolicies(sourceNode, toNode route.Vertex, // addPolicy adds a single channel policy. Capacity may be zero if unknown // (light clients). func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, - edge *channeldb.ChannelEdgePolicy, capacity btcutil.Amount) { + edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) { localChan := fromNode == u.sourceNode @@ -69,24 +69,18 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { - cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _, - inEdge *channeldb.ChannelEdgePolicy) error { - + cb := func(channel *channeldb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have // come prior to the pivot node in the route. - if inEdge == nil { + if channel.InPolicy == nil { return nil } - // The node on the other end of this channel is the from node. - fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:]) - if err != nil { - return err - } - // Add this policy to the unified policies map. - u.addPolicy(fromNode, inEdge, edgeInfo.Capacity) + u.addPolicy( + channel.OtherNode, channel.InPolicy, channel.Capacity, + ) return nil } @@ -98,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { // unifiedPolicyEdge is the individual channel data that is kept inside an // unifiedPolicy object. type unifiedPolicyEdge struct { - policy *channeldb.ChannelEdgePolicy + policy *channeldb.CachedEdgePolicy capacity btcutil.Amount } @@ -139,7 +133,7 @@ type unifiedPolicy struct { // specific amount to send. It differentiates between local and network // channels. func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { if u.localChan { return u.getPolicyLocal(amt, bandwidthHints) @@ -151,10 +145,10 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, // getPolicyLocal returns the optimal policy to use for this local connection // given a specific amount to send. 
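Several call sites in this diff switch from a (ChannelEdgeInfo, ChannelEdgePolicy) pair to a single channeldb.DirectedChannel: the bandwidth callbacks read ChannelID and Capacity, while addGraphPolicies above reads OtherNode, InPolicy and Capacity. The type itself is not defined in this diff; from those usages it plausibly carries at least the fields below (sketch only, the real definition may have more):

// DirectedChannel describes one channel as seen from the perspective of a
// particular node. Field set inferred from the call sites in this diff.
type DirectedChannel struct {
	// ChannelID is the channel's short channel ID.
	ChannelID uint64

	// Capacity is the total capacity of the channel.
	Capacity btcutil.Amount

	// OtherNode is the vertex of the node on the far end of the channel.
	OtherNode route.Vertex

	// InPolicy is the policy advertised for forwarding towards this
	// node, i.e. the edge that matters when searching backwards.
	InPolicy *CachedEdgePolicy
}

getPolicyLocal below then selects among the cached policies gathered from such channels.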
func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { var ( - bestPolicy *channeldb.ChannelEdgePolicy + bestPolicy *channeldb.CachedEdgePolicy maxBandwidth lnwire.MilliSatoshi ) @@ -206,10 +200,10 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, // a specific amount to send. The goal is to return a policy that maximizes the // probability of a successful forward in a non-strict forwarding context. func (u *unifiedPolicy) getPolicyNetwork( - amt lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + amt lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { var ( - bestPolicy *channeldb.ChannelEdgePolicy + bestPolicy *channeldb.CachedEdgePolicy maxFee lnwire.MilliSatoshi maxTimelock uint16 ) diff --git a/routing/unified_policies_test.go b/routing/unified_policies_test.go index e89a3cb12..ac915f99a 100644 --- a/routing/unified_policies_test.go +++ b/routing/unified_policies_test.go @@ -20,7 +20,7 @@ func TestUnifiedPolicies(t *testing.T) { u := newUnifiedPolicies(source, toNode, nil) // Add two channels between the pair of nodes. - p1 := channeldb.ChannelEdgePolicy{ + p1 := channeldb.CachedEdgePolicy{ FeeProportionalMillionths: 100000, FeeBaseMSat: 30, TimeLockDelta: 60, @@ -28,7 +28,7 @@ func TestUnifiedPolicies(t *testing.T) { MaxHTLC: 500, MinHTLC: 100, } - p2 := channeldb.ChannelEdgePolicy{ + p2 := channeldb.CachedEdgePolicy{ FeeProportionalMillionths: 190000, FeeBaseMSat: 10, TimeLockDelta: 40, @@ -39,7 +39,7 @@ func TestUnifiedPolicies(t *testing.T) { u.addPolicy(fromNode, &p1, 7) u.addPolicy(fromNode, &p2, 7) - checkPolicy := func(policy *channeldb.ChannelEdgePolicy, + checkPolicy := func(policy *channeldb.CachedEdgePolicy, feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, timeLockDelta uint16) { diff --git a/rpcserver.go b/rpcserver.go index 54d62c07b..7a3b80939 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -3989,7 +3989,7 @@ func (r *rpcServer) createRPCClosedChannel( CloseInitiator: closeInitiator, } - reports, err := r.server.chanStateDB.FetchChannelReports( + reports, err := r.server.miscDB.FetchChannelReports( *r.cfg.ActiveNetParams.GenesisHash, &dbChannel.ChanPoint, ) switch err { @@ -5152,7 +5152,7 @@ func (r *rpcServer) ListInvoices(ctx context.Context, PendingOnly: req.PendingOnly, Reversed: req.Reversed, } - invoiceSlice, err := r.server.chanStateDB.QueryInvoices(q) + invoiceSlice, err := r.server.miscDB.QueryInvoices(q) if err != nil { return nil, fmt.Errorf("unable to query invoices: %v", err) } @@ -5549,7 +5549,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // With the public key decoded, attempt to fetch the node corresponding // to this public key. If the node cannot be found, then an error will // be returned. 
- node, err := graph.FetchLightningNode(nil, pubKey) + node, err := graph.FetchLightningNode(pubKey) switch { case err == channeldb.ErrGraphNodeNotFound: return nil, status.Error(codes.NotFound, err.Error()) @@ -5954,7 +5954,7 @@ func (r *rpcServer) ListPayments(ctx context.Context, query.MaxPayments = math.MaxUint64 } - paymentsQuerySlice, err := r.server.chanStateDB.QueryPayments(query) + paymentsQuerySlice, err := r.server.miscDB.QueryPayments(query) if err != nil { return nil, err } @@ -5995,9 +5995,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context, rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+ "failed_htlcs_only=%v", hash, req.FailedHtlcsOnly) - err = r.server.chanStateDB.DeletePayment( - hash, req.FailedHtlcsOnly, - ) + err = r.server.miscDB.DeletePayment(hash, req.FailedHtlcsOnly) if err != nil { return nil, err } @@ -6014,7 +6012,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, "failed_htlcs_only=%v", req.FailedPaymentsOnly, req.FailedHtlcsOnly) - err := r.server.chanStateDB.DeletePayments( + err := r.server.miscDB.DeletePayments( req.FailedPaymentsOnly, req.FailedHtlcsOnly, ) if err != nil { @@ -6176,7 +6174,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, return nil, err } - fwdEventLog := r.server.chanStateDB.ForwardingLog() + fwdEventLog := r.server.miscDB.ForwardingLog() // computeFeeSum is a helper function that computes the total fees for // a particular time slice described by a forwarding event query. @@ -6417,7 +6415,7 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context, IndexOffset: req.IndexOffset, NumMaxEvents: numEvents, } - timeSlice, err := r.server.chanStateDB.ForwardingLog().Query(eventQuery) + timeSlice, err := r.server.miscDB.ForwardingLog().Query(eventQuery) if err != nil { return nil, fmt.Errorf("unable to query forwarding log: %v", err) } @@ -6479,7 +6477,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context, // the database. If this channel has been closed, or the outpoint is // unknown, then we'll return an error unpackedBackup, err := chanbackup.FetchBackupForChan( - chanPoint, r.server.chanStateDB, + chanPoint, r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, err @@ -6649,7 +6647,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context, // First, we'll attempt to read back ups for ALL currently opened // channels from disk. allUnpackedBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, + r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, fmt.Errorf("unable to fetch all static chan "+ @@ -6776,7 +6774,7 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription // we'll obtains the current set of single channel // backups from disk. chanBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, + r.server.chanStateDB, r.server.addrSource, ) if err != nil { return fmt.Errorf("unable to fetch all "+ diff --git a/server.go b/server.go index d80f5bb79..2d225d4c5 100644 --- a/server.go +++ b/server.go @@ -222,7 +222,13 @@ type server struct { graphDB *channeldb.ChannelGraph - chanStateDB *channeldb.DB + chanStateDB *channeldb.ChannelStateDB + + addrSource chanbackup.AddressSource + + // miscDB is the DB that contains all "other" databases within the main + // channel DB that haven't been separated out yet. 
+ miscDB *channeldb.DB htlcSwitch *htlcswitch.Switch @@ -432,14 +438,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, graphDB: dbs.graphDB.ChannelGraph(), - chanStateDB: dbs.chanStateDB, + chanStateDB: dbs.chanStateDB.ChannelStateDB(), + addrSource: dbs.chanStateDB, + miscDB: dbs.chanStateDB, cc: cc, sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), writePool: writePool, readPool: readPool, chansToRestore: chansToRestore, - channelNotifier: channelnotifier.New(dbs.chanStateDB), + channelNotifier: channelnotifier.New( + dbs.chanStateDB.ChannelStateDB(), + ), identityECDH: nodeKeyECDH, nodeSigner: netann.NewNodeSigner(nodeKeySigner), @@ -494,7 +504,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, thresholdMSats := lnwire.NewMSatFromSatoshis(thresholdSats) s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ - DB: dbs.chanStateDB, + DB: dbs.chanStateDB, + FetchAllOpenChannels: s.chanStateDB.FetchAllOpenChannels, + FetchClosedChannels: s.chanStateDB.FetchClosedChannels, LocalChannelClose: func(pubKey []byte, request *htlcswitch.ChanClose) { @@ -537,7 +549,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MessageSigner: s.nodeSigner, IsChannelActive: s.htlcSwitch.HasActiveLink, ApplyChannelUpdate: s.applyChannelUpdate, - DB: dbs.chanStateDB, + DB: s.chanStateDB, Graph: dbs.graphDB.ChannelGraph(), } @@ -702,9 +714,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } - queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) - link, err := s.htlcSwitch.GetLink(cid) + queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + cid := lnwire.NewShortChanIDFromInt(c.ChannelID) + link, err := s.htlcSwitch.GetLinkByShortID(cid) if err != nil { // If the link isn't online, then we'll report // that it has zero bandwidth to the router. 
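The hunk above only shows the reworked head of the bandwidth callback: the link is now looked up by the short channel ID carried on the DirectedChannel instead of a channel ID derived from the funding outpoint. The remainder of the closure is unchanged and therefore omitted from the diff; for context, the complete callback after this change plausibly reads as follows (unchanged tail reconstructed, not part of the patch):

queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
	cid := lnwire.NewShortChanIDFromInt(c.ChannelID)
	link, err := s.htlcSwitch.GetLinkByShortID(cid)
	if err != nil {
		// The link isn't online, so report zero bandwidth to the
		// router.
		return 0
	}

	// A link that isn't yet eligible to forward HTLCs is treated the
	// same as an offline one.
	if !link.EligibleToForward() {
		return 0
	}

	// Otherwise report the link's current available bandwidth.
	return link.Bandwidth()
}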
@@ -768,8 +780,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MinProbability: routingConfig.MinRouteProbability, } + cachedGraph, err := routing.NewCachedGraph(chanGraph) + if err != nil { + return nil, err + } paymentSessionSource := &routing.SessionSource{ - Graph: chanGraph, + Graph: cachedGraph, MissionControl: s.missionControl, QueryBandwidth: queryBandwidth, PathFindingConfig: pathFindingConfig, @@ -805,11 +821,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } chanSeries := discovery.NewChanSeries(s.graphDB) - gossipMessageStore, err := discovery.NewMessageStore(s.chanStateDB) + gossipMessageStore, err := discovery.NewMessageStore(dbs.chanStateDB) if err != nil { return nil, err } - waitingProofStore, err := channeldb.NewWaitingProofStore(s.chanStateDB) + waitingProofStore, err := channeldb.NewWaitingProofStore(dbs.chanStateDB) if err != nil { return nil, err } @@ -891,8 +907,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ ChainIO: cc.ChainIO, ConfDepth: 1, - FetchClosedChannels: dbs.chanStateDB.FetchClosedChannels, - FetchClosedChannel: dbs.chanStateDB.FetchClosedChannel, + FetchClosedChannels: s.chanStateDB.FetchClosedChannels, + FetchClosedChannel: s.chanStateDB.FetchClosedChannel, Notifier: cc.ChainNotifier, PublishTransaction: cc.Wallet.PublishTransaction, Store: utxnStore, @@ -1018,7 +1034,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.breachArbiter = contractcourt.NewBreachArbiter(&contractcourt.BreachConfig{ CloseLink: closeLink, - DB: dbs.chanStateDB, + DB: s.chanStateDB, Estimator: s.cc.FeeEstimator, GenSweepScript: newSweepPkScriptGen(cc.Wallet), Notifier: cc.ChainNotifier, @@ -1075,7 +1091,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, FindChannel: func(chanID lnwire.ChannelID) ( *channeldb.OpenChannel, error) { - dbChannels, err := dbs.chanStateDB.FetchAllChannels() + dbChannels, err := s.chanStateDB.FetchAllChannels() if err != nil { return nil, err } @@ -1247,10 +1263,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // static backup of the latest channel state. chanNotifier := &channelNotifier{ chanNotifier: s.channelNotifier, - addrs: s.chanStateDB, + addrs: dbs.chanStateDB, } backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) - startingChans, err := chanbackup.FetchStaticChanBackups(s.chanStateDB) + startingChans, err := chanbackup.FetchStaticChanBackups( + s.chanStateDB, s.addrSource, + ) if err != nil { return nil, err } @@ -1275,8 +1293,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, GetOpenChannels: s.chanStateDB.FetchAllOpenChannels, Clock: clock.NewDefaultClock(), - ReadFlapCount: s.chanStateDB.ReadFlapCount, - WriteFlapCount: s.chanStateDB.WriteFlapCounts, + ReadFlapCount: s.miscDB.ReadFlapCount, + WriteFlapCount: s.miscDB.WriteFlapCounts, FlapCountTicker: ticker.New(chanfitness.FlapCountFlushRate), }) @@ -2531,7 +2549,7 @@ func (s *server) establishPersistentConnections() error { // Iterate through the list of LinkNodes to find addresses we should // attempt to connect to based on our set of previous connections. Set // the reconnection port to the default peer port. 
- linkNodes, err := s.chanStateDB.FetchAllLinkNodes() + linkNodes, err := s.chanStateDB.LinkNodeDB().FetchAllLinkNodes() if err != nil && err != channeldb.ErrLinkNodesNotFound { return err } @@ -3911,7 +3929,7 @@ func (s *server) fetchNodeAdvertisedAddr(pub *btcec.PublicKey) (net.Addr, error) return nil, err } - node, err := s.graphDB.FetchLightningNode(nil, vertex) + node, err := s.graphDB.FetchLightningNode(vertex) if err != nil { return nil, err } diff --git a/subrpcserver_config.go b/subrpcserver_config.go index bf5911ec2..04853db76 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -93,7 +93,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, routerBackend *routerrpc.RouterBackend, nodeSigner *netann.NodeSigner, graphDB *channeldb.ChannelGraph, - chanStateDB *channeldb.DB, + chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, towerClient wtclient.Client,