From c5c2fc27f90f9a9de853fc090fc4ce13aa48bfec Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 5 Feb 2020 15:39:31 +0200 Subject: [PATCH 1/2] channeldb/test: replace test channel boilerplate createTestChannel This change replaces test channel creation boilerplate with a createTestChannel function which can be customized using functional options. --- channeldb/channel_test.go | 346 ++++++++++++++++++++++---------------- channeldb/db_test.go | 59 ++----- 2 files changed, 210 insertions(+), 195 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 891f0d824..927c380c3 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -71,6 +71,17 @@ var ( wireSig, _ = lnwire.NewSigFromSignature(testSig) testClock = clock.NewTestClock(testNow) + + // defaultPendingHeight is the default height at which we set + // channels to pending. + defaultPendingHeight = 100 + + // defaultAddr is the default address that we mark test channels pending + // with. + defaultAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } ) // makeTestDB creates a new instance of the ChannelDB for testing purposes. A @@ -98,21 +109,141 @@ func makeTestDB() (*DB, func(), error) { return cdb, cleanUp, nil } -func createTestChannelState(cdb *DB) (*OpenChannel, error) { +// testChannelParams is a struct which details the specifics of how a channel +// should be created. +type testChannelParams struct { + // channel is the channel that will be written to disk. + channel *OpenChannel + + // addr is the address that the channel will be synced pending with. + addr *net.TCPAddr + + // pendingHeight is the height that the channel should be recorded as + // pending. + pendingHeight uint32 + + // openChannel is set to true if the channel should be fully marked as + // open if this is false, the channel will be left in pending state. + openChannel bool +} + +// testChannelOption is a functional option which can be used to alter the +// default channel that is creates for testing. +type testChannelOption func(params *testChannelParams) + +// pendingHeightOption is an option which can be used to set the height the +// channel is marked as pending at. +func pendingHeightOption(height uint32) testChannelOption { + return func(params *testChannelParams) { + params.pendingHeight = height + } +} + +// openChannelOption is an option which can be used to create a test channel +// that is open. +func openChannelOption() testChannelOption { + return func(params *testChannelParams) { + params.openChannel = true + } +} + +// localHtlcsOption is an option which allows setting of htlcs on the local +// commitment. +func localHtlcsOption(htlcs []HTLC) testChannelOption { + return func(params *testChannelParams) { + params.channel.LocalCommitment.Htlcs = htlcs + } +} + +// remoteHtlcsOption is an option which allows setting of htlcs on the remote +// commitment. +func remoteHtlcsOption(htlcs []HTLC) testChannelOption { + return func(params *testChannelParams) { + params.channel.RemoteCommitment.Htlcs = htlcs + } +} + +// localShutdownOption is an option which sets the local upfront shutdown +// script for the channel. +func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption { + return func(params *testChannelParams) { + params.channel.LocalShutdownScript = addr + } +} + +// remoteShutdownOption is an option which sets the remote upfront shutdown +// script for the channel. +func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption { + return func(params *testChannelParams) { + params.channel.RemoteShutdownScript = addr + } +} + +// fundingPointOption is an option which sets the funding outpoint of the +// channel. +func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { + return func(params *testChannelParams) { + params.channel.FundingOutpoint = chanPoint + } +} + +// createTestChannel writes a test channel to the database. It takes a set of +// functional options which can be used to overwrite the default of creating +// a pending channel that was broadcast at height 100. +func createTestChannel(t *testing.T, cdb *DB, + opts ...testChannelOption) *OpenChannel { + + // Create a default set of parameters. + params := &testChannelParams{ + channel: createTestChannelState(t, cdb), + addr: defaultAddr, + openChannel: false, + pendingHeight: uint32(defaultPendingHeight), + } + + // Apply all functional options to the test channel params. + for _, o := range opts { + o(params) + } + + // Mark the channel as pending. + err := params.channel.SyncPending(params.addr, params.pendingHeight) + if err != nil { + t.Fatalf("unable to save and serialize channel "+ + "state: %v", err) + } + + // If the parameters do not specify that we should open the channel + // fully, we return the pending channel. + if !params.openChannel { + return params.channel + } + + // Mark the channel as open with the short channel id provided. + err = params.channel.MarkAsOpen(params.channel.ShortChannelID) + if err != nil { + t.Fatalf("unable to mark channel open: %v", err) + } + + return params.channel +} + +func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { // Simulate 1000 channel updates. producer, err := shachain.NewRevocationProducerFromBytes(key[:]) if err != nil { - return nil, err + t.Fatalf("could not get producer: %v", err) } store := shachain.NewRevocationStore() for i := 0; i < 1; i++ { preImage, err := producer.AtIndex(uint64(i)) if err != nil { - return nil, err + t.Fatalf("could not get "+ + "preimage: %v", err) } if err := store.AddNextEntry(preImage); err != nil { - return nil, err + t.Fatalf("could not add entry: %v", err) } } @@ -228,7 +359,7 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) { Db: cdb, Packager: NewChannelPackager(chanID), FundingTxn: testTx, - }, nil + } } func TestOpenChannelPutGetDelete(t *testing.T) { @@ -240,15 +371,10 @@ func TestOpenChannelPutGetDelete(t *testing.T) { } defer cleanUp() - // Create the test channel state, then add an additional fake HTLC - // before syncing to disk. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - state.LocalCommitment.Htlcs = []HTLC{ - { - Signature: testSig.Serialize(), + // Create the test channel state, with additional htlcs on the local + // and remote commitment. + localHtlcs := []HTLC{ + {Signature: testSig.Serialize(), Incoming: true, Amt: 10, RHash: key, @@ -256,7 +382,8 @@ func TestOpenChannelPutGetDelete(t *testing.T) { OnionBlob: []byte("onionblob"), }, } - state.RemoteCommitment.Htlcs = []HTLC{ + + remoteHtlcs := []HTLC{ { Signature: testSig.Serialize(), Incoming: false, @@ -267,13 +394,11 @@ func TestOpenChannelPutGetDelete(t *testing.T) { }, } - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } + state := createTestChannel( + t, cdb, + remoteHtlcsOption(remoteHtlcs), + localHtlcsOption(localHtlcs), + ) openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) if err != nil { @@ -360,36 +485,28 @@ func TestOptionalShutdown(t *testing.T) { tests := []struct { name string - modifyChannel func(channel *OpenChannel) - expectedLocal lnwire.DeliveryAddress - expectedRemote lnwire.DeliveryAddress + localShutdown lnwire.DeliveryAddress + remoteShutdown lnwire.DeliveryAddress }{ { - name: "no shutdown scripts", - modifyChannel: func(channel *OpenChannel) {}, + name: "no shutdown scripts", + localShutdown: nil, + remoteShutdown: nil, }, { - name: "local shutdown script", - modifyChannel: func(channel *OpenChannel) { - channel.LocalShutdownScript = local - }, - expectedLocal: local, + name: "local shutdown script", + localShutdown: local, + remoteShutdown: nil, }, { - name: "remote shutdown script", - modifyChannel: func(channel *OpenChannel) { - channel.RemoteShutdownScript = remote - }, - expectedRemote: remote, + name: "remote shutdown script", + localShutdown: nil, + remoteShutdown: remote, }, { - name: "both scripts set", - modifyChannel: func(channel *OpenChannel) { - channel.LocalShutdownScript = local - channel.RemoteShutdownScript = remote - }, - expectedLocal: local, - expectedRemote: remote, + name: "both scripts set", + localShutdown: local, + remoteShutdown: remote, }, } @@ -403,40 +520,40 @@ func TestOptionalShutdown(t *testing.T) { } defer cleanUp() - // Create the test channel state, then add an additional fake HTLC - // before syncing to disk. - state, err := createTestChannelState(cdb) + // Create a channel with upfront scripts set as + // specified in the test. + state := createTestChannel( + t, cdb, + localShutdownOption(test.localShutdown), + remoteShutdownOption(test.remoteShutdown), + ) + + openChannels, err := cdb.FetchOpenChannels( + state.IdentityPub, + ) if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - test.modifyChannel(state) - - // Write channels to Db. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channel: %v", err) + t.Fatalf("unable to fetch open"+ + " channel: %v", err) } if len(openChannels) != 1 { - t.Fatalf("Expected one channel open, got: %v", len(openChannels)) + t.Fatalf("Expected one channel open,"+ + " got: %v", len(openChannels)) } - if !bytes.Equal(openChannels[0].LocalShutdownScript, test.expectedLocal) { - t.Fatalf("Expected local: %x, got: %x", test.expectedLocal, + if !bytes.Equal(openChannels[0].LocalShutdownScript, + test.localShutdown) { + + t.Fatalf("Expected local: %x, got: %x", + test.localShutdown, openChannels[0].LocalShutdownScript) } - if !bytes.Equal(openChannels[0].RemoteShutdownScript, test.expectedRemote) { - t.Fatalf("Expected remote: %x, got: %x", test.expectedRemote, + if !bytes.Equal(openChannels[0].RemoteShutdownScript, + test.remoteShutdown) { + + t.Fatalf("Expected remote: %x, got: %x", + test.remoteShutdown, openChannels[0].RemoteShutdownScript) } }) @@ -462,18 +579,7 @@ func TestChannelStateTransition(t *testing.T) { // First create a minimal channel, then perform a full sync in order to // persist the data. - channel, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := channel.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } + channel := createTestChannel(t, cdb) // Add some HTLCs which were added during this new state transition. // Half of the HTLCs are incoming, while the other half are outgoing. @@ -776,21 +882,9 @@ func TestFetchPendingChannels(t *testing.T) { } defer cleanUp() - // Create first test channel state - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - + // Create a pending channel that was broadcast at height 99. const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } + createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) pendingChannels, err := cdb.FetchPendingChannels() if err != nil { @@ -867,35 +961,8 @@ func TestFetchClosedChannels(t *testing.T) { } defer cleanUp() - // First create a test channel, that we'll be closing within this pull - // request. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Next sync the channel to disk, marking it as being in a pending open - // state. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Next, simulate the confirmation of the channel by marking it as - // pending within the database. - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 5, - TxIndex: 10, - TxPosition: 15, - } - err = state.MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } + // Create an open channel in the database. + state := createTestChannel(t, cdb, openChannelOption()) // Next, close the channel by including a close channel summary in the // database. @@ -975,7 +1042,6 @@ func TestFetchWaitingCloseChannels(t *testing.T) { const numChannels = 2 const broadcastHeight = 99 - addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555} // We'll start by creating two channels within our test database. One of // them will have their funding transaction confirmed on-chain, while @@ -988,15 +1054,11 @@ func TestFetchWaitingCloseChannels(t *testing.T) { channels := make([]*OpenChannel, numChannels) for i := 0; i < numChannels; i++ { - channel, err := createTestChannelState(db) - if err != nil { - t.Fatalf("unable to create channel: %v", err) - } - err = channel.SyncPending(addr, broadcastHeight) - if err != nil { - t.Fatalf("unable to sync channel: %v", err) - } - channels[i] = channel + // Create a pending channel in the database at the broadcast + // height. + channels[i] = createTestChannel( + t, db, pendingHeightOption(broadcastHeight), + ) } // We'll only confirm the first one. @@ -1106,21 +1168,7 @@ func TestRefreshShortChanID(t *testing.T) { defer cleanUp() // First create a test channel. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - - // Mark the channel as pending within the channeldb. - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } + state := createTestChannel(t, cdb) // Next, locate the pending channel with the database. pendingChannels, err := cdb.FetchPendingChannels() diff --git a/channeldb/db_test.go b/channeldb/db_test.go index b9b1189be..c1f525077 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -100,10 +100,7 @@ func TestFetchClosedChannelForID(t *testing.T) { // Create the test channel state, that we will mutate the index of the // funding point. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } + state := createTestChannelState(t, cdb) // Now run through the number of channels, and modify the outpoint index // to create new channel IDs. @@ -111,14 +108,12 @@ func TestFetchClosedChannelForID(t *testing.T) { // Save the open channel to disk. state.FundingOutpoint.Index = i - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel "+ - "state: %v", err) - } + // Write the channel to disk in a pending state. + createTestChannel( + t, cdb, + fundingPointOption(state.FundingOutpoint), + openChannelOption(), + ) // Close the channel. To make sure we retrieve the correct // summary later, we make them differ in the SettledBalance. @@ -235,26 +230,8 @@ func TestFetchChannel(t *testing.T) { } defer cleanUp() - // Create the test channel state that we'll sync to the database - // shortly. - channelState, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Mark the channel as pending, then immediately mark it as open to it - // can be fully visible. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - if err := channelState.SyncPending(addr, 9); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99)) - if err != nil { - t.Fatalf("unable to mark channel open: %v", err) - } + // Create an open channel. + channelState := createTestChannel(t, cdb, openChannelOption()) // Next, attempt to fetch the channel by its chan point. dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) @@ -271,7 +248,7 @@ func TestFetchChannel(t *testing.T) { // If we attempt to query for a non-exist ante channel, then we should // get an error. - channelState2, err := createTestChannelState(cdb) + channelState2 := createTestChannelState(t, cdb) if err != nil { t.Fatalf("unable to create channel state: %v", err) } @@ -491,19 +468,9 @@ func TestAbandonChannel(t *testing.T) { t.Fatalf("removing non-existent channel should have failed") } - // We'll now create a new channel to abandon shortly. - chanState, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - err = chanState.SyncPending(addr, 10) - if err != nil { - t.Fatalf("unable to sync pending channel: %v", err) - } + // We'll now create a new channel in a pending state to abandon + // shortly. + chanState := createTestChannel(t, cdb) // We should now be able to abandon the channel without any errors. closeHeight := uint32(11) From ed81c882395af051e3d4e781c811ab2b0e085a18 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 6 Feb 2020 10:21:12 +0200 Subject: [PATCH 2/2] channeldb: replace fetch channels booleans with optional filters This changes replaces the pending an waiting booleans in fetchChannels with optional filters which can be more flexibly used. This change allows filtering of channels without having to reason about the matrix of possible boolean combinations. A test is added to ensure that the combinations of these filters act as expected. --- channeldb/channel_test.go | 7 ++ channeldb/db.go | 121 +++++++++++++----------- channeldb/db_test.go | 193 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 54 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 927c380c3..cb29b5213 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -187,6 +187,13 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { } } +// channelIDOption is an option which sets the short channel ID of the channel. +var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { + return func(params *testChannelParams) { + params.channel.ShortChannelID = chanID + } +} + // createTestChannel writes a test channel to the database. It takes a set of // functional options which can be used to overwrite the default of creating // a pending channel that was broadcast at height 100. diff --git a/channeldb/db.go b/channeldb/db.go index c9410bfbb..696f3f008 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -556,42 +556,28 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { // within the database, including pending open, fully open and channels waiting // for a closing transaction to confirm. func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - var channels []*OpenChannel - - // TODO(halseth): fetch all in one db tx. - openChannels, err := d.FetchAllOpenChannels() - if err != nil { - return nil, err - } - channels = append(channels, openChannels...) - - pendingChannels, err := d.FetchPendingChannels() - if err != nil { - return nil, err - } - channels = append(channels, pendingChannels...) - - waitingClose, err := d.FetchWaitingCloseChannels() - if err != nil { - return nil, err - } - channels = append(channels, waitingClose...) - - return channels, nil + return fetchChannels(d) } // FetchAllOpenChannels will return all channels that have the funding // transaction confirmed, and is not waiting for a closing transaction to be // confirmed. func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { - return fetchChannels(d, false, false) + return fetchChannels( + d, + pendingChannelFilter(false), + waitingCloseFilter(false), + ) } // FetchPendingChannels will return channels that have completed the process of // generating and broadcasting funding transactions, but whose funding // transactions have yet to be confirmed on the blockchain. func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, true, false) + return fetchChannels(d, + pendingChannelFilter(true), + waitingCloseFilter(false), + ) } // FetchWaitingCloseChannels will return all channels that have been opened, @@ -599,25 +585,49 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { // // NOTE: This includes channels that are also pending to be opened. func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { - waitingClose, err := fetchChannels(d, false, true) - if err != nil { - return nil, err - } - pendingWaitingClose, err := fetchChannels(d, true, true) - if err != nil { - return nil, err - } + return fetchChannels( + d, waitingCloseFilter(true), + ) +} - return append(waitingClose, pendingWaitingClose...), nil +// fetchChannelsFilter applies a filter to channels retrieved in fetchchannels. +// A set of filters can be combined to filter across multiple dimensions. +type fetchChannelsFilter func(channel *OpenChannel) bool + +// pendingChannelFilter returns a filter based on whether channels are pending +// (ie, their funding transaction still needs to confirm). If pending is false, +// channels with confirmed funding transactions are returned. +func pendingChannelFilter(pending bool) fetchChannelsFilter { + return func(channel *OpenChannel) bool { + return channel.IsPending == pending + } +} + +// waitingCloseFilter returns a filter which filters channels based on whether +// they are awaiting the confirmation of their closing transaction. If waiting +// close is true, channels that have had their closing tx broadcast are +// included. If it is false, channels that are not awaiting confirmation of +// their close transaction are returned. +func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { + return func(channel *OpenChannel) bool { + // If the channel is in any other state than Default, + // then it means it is waiting to be closed. + channelWaitingClose := + channel.ChanStatus() != ChanStatusDefault + + // Include the channel if it matches the value for + // waiting close that we are filtering on. + return channelWaitingClose == waitingClose + } } // fetchChannels attempts to retrieve channels currently stored in the -// database. The pending parameter determines whether only pending channels -// will be returned, or only open channels will be returned. The waitingClose -// parameter determines whether only channels waiting for a closing transaction -// to be confirmed should be returned. If no active channels exist within the -// network, then ErrNoActiveChannels is returned. -func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { +// database. It takes a set of filters which are applied to each channel to +// obtain a set of channels with the desired set of properties. Only channels +// which have a true value returned for *all* of the filters will be returned. +// If no filters are provided, every channel in the open channels bucket will +// be returned. +func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) { var channels []*OpenChannel err := d.View(func(tx *bbolt.Tx) error { @@ -667,24 +677,27 @@ func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { "node_key=%x: %v", chainHash[:], k, err) } for _, channel := range nodeChans { - if channel.IsPending != pending { - continue + // includeChannel indicates whether the channel + // meets the criteria specified by our filters. + includeChannel := true + + // Run through each filter and check whether the + // channel should be included. + for _, f := range filters { + // If the channel fails the filter, set + // includeChannel to false and don't bother + // checking the remaining filters. + if !f(channel) { + includeChannel = false + break + } } - // If the channel is in any other state - // than Default, then it means it is - // waiting to be closed. - channelWaitingClose := - channel.ChanStatus() != ChanStatusDefault - - // Only include it if we requested - // channels with the same waitingClose - // status. - if channelWaitingClose != waitingClose { - continue + // If the channel passed every filter, include it in + // our set of channels. + if includeChannel { + channels = append(channels, channel) } - - channels = append(channels, channel) } return nil }) diff --git a/channeldb/db_test.go b/channeldb/db_test.go index c1f525077..935a287ce 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -500,3 +500,196 @@ func TestAbandonChannel(t *testing.T) { t.Fatalf("unable to abandon channel: %v", err) } } + +// TestFetchChannels tests the filtering of open channels in fetchChannels. +// It tests the case where no filters are provided (which is equivalent to +// FetchAllOpenChannels) and every combination of pending and waiting close. +func TestFetchChannels(t *testing.T) { + // Create static channel IDs for each kind of channel retrieved by + // fetchChannels so that the expected channel IDs can be set in tests. + var ( + // Pending is a channel that is pending open, and has not had + // a close initiated. + pendingChan = lnwire.NewShortChanIDFromInt(1) + + // pendingWaitingClose is a channel that is pending open and + // has has its closing transaction broadcast. + pendingWaitingChan = lnwire.NewShortChanIDFromInt(2) + + // openChan is a channel that has confirmed on chain. + openChan = lnwire.NewShortChanIDFromInt(3) + + // openWaitingChan is a channel that has confirmed on chain, + // and it waiting for its close transaction to confirm. + openWaitingChan = lnwire.NewShortChanIDFromInt(4) + ) + + tests := []struct { + name string + filters []fetchChannelsFilter + expectedChannels map[lnwire.ShortChannelID]bool + }{ + { + name: "get all channels", + filters: []fetchChannelsFilter{}, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + pendingWaitingChan: true, + openChan: true, + openWaitingChan: true, + }, + }, + { + name: "pending channels", + filters: []fetchChannelsFilter{ + pendingChannelFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + pendingWaitingChan: true, + }, + }, + { + name: "open channels", + filters: []fetchChannelsFilter{ + pendingChannelFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + openChan: true, + openWaitingChan: true, + }, + }, + { + name: "waiting close channels", + filters: []fetchChannelsFilter{ + waitingCloseFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingWaitingChan: true, + openWaitingChan: true, + }, + }, + { + name: "not waiting close channels", + filters: []fetchChannelsFilter{ + waitingCloseFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + openChan: true, + }, + }, + { + name: "pending waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(true), + waitingCloseFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingWaitingChan: true, + }, + }, + { + name: "pending, not waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(true), + waitingCloseFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + }, + }, + { + name: "open waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(false), + waitingCloseFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + openWaitingChan: true, + }, + }, + { + name: "open, not waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(false), + waitingCloseFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + openChan: true, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test "+ + "database: %v", err) + } + defer cleanUp() + + // Create a pending channel that is not awaiting close. + createTestChannel( + t, cdb, channelIDOption(pendingChan), + ) + + // Create a pending channel which has has been marked as + // broadcast, indicating that its closing transaction is + // waiting to confirm. + pendingClosing := createTestChannel( + t, cdb, + channelIDOption(pendingWaitingChan), + ) + + err = pendingClosing.MarkCoopBroadcasted(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Create a open channel that is not awaiting close. + createTestChannel( + t, cdb, + channelIDOption(openChan), + openChannelOption(), + ) + + // Create a open channel which has has been marked as + // broadcast, indicating that its closing transaction is + // waiting to confirm. + openClosing := createTestChannel( + t, cdb, + channelIDOption(openWaitingChan), + openChannelOption(), + ) + err = openClosing.MarkCoopBroadcasted(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + channels, err := fetchChannels(cdb, test.filters...) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(channels) != len(test.expectedChannels) { + t.Fatalf("expected: %v channels, "+ + "got: %v", len(test.expectedChannels), + len(channels)) + } + + for _, ch := range channels { + _, ok := test.expectedChannels[ch.ShortChannelID] + if !ok { + t.Fatalf("fetch channels unexpected "+ + "channel: %v", ch.ShortChannelID) + } + } + }) + } +}