Merge pull request #3984 from carlaKC/channeldb-fetchchannels

channeldb: Replace fetchChannels pending/waiting bools with optional filters
This commit is contained in:
Carla Kirk-Cohen
2020-02-11 08:40:54 +02:00
committed by GitHub
3 changed files with 477 additions and 249 deletions

View File

@@ -71,6 +71,17 @@ var (
wireSig, _ = lnwire.NewSigFromSignature(testSig) wireSig, _ = lnwire.NewSigFromSignature(testSig)
testClock = clock.NewTestClock(testNow) testClock = clock.NewTestClock(testNow)
// defaultPendingHeight is the default height at which we set
// channels to pending.
defaultPendingHeight = 100
// defaultAddr is the default address that we mark test channels pending
// with.
defaultAddr = &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
) )
// makeTestDB creates a new instance of the ChannelDB for testing purposes. A // makeTestDB creates a new instance of the ChannelDB for testing purposes. A
@@ -98,21 +109,148 @@ func makeTestDB() (*DB, func(), error) {
return cdb, cleanUp, nil return cdb, cleanUp, nil
} }
func createTestChannelState(cdb *DB) (*OpenChannel, error) { // testChannelParams is a struct which details the specifics of how a channel
// should be created.
type testChannelParams struct {
// channel is the channel that will be written to disk.
channel *OpenChannel
// addr is the address that the channel will be synced pending with.
addr *net.TCPAddr
// pendingHeight is the height that the channel should be recorded as
// pending.
pendingHeight uint32
// openChannel is set to true if the channel should be fully marked as
// open if this is false, the channel will be left in pending state.
openChannel bool
}
// testChannelOption is a functional option which can be used to alter the
// default channel that is creates for testing.
type testChannelOption func(params *testChannelParams)
// pendingHeightOption is an option which can be used to set the height the
// channel is marked as pending at.
func pendingHeightOption(height uint32) testChannelOption {
return func(params *testChannelParams) {
params.pendingHeight = height
}
}
// openChannelOption is an option which can be used to create a test channel
// that is open.
func openChannelOption() testChannelOption {
return func(params *testChannelParams) {
params.openChannel = true
}
}
// localHtlcsOption is an option which allows setting of htlcs on the local
// commitment.
func localHtlcsOption(htlcs []HTLC) testChannelOption {
return func(params *testChannelParams) {
params.channel.LocalCommitment.Htlcs = htlcs
}
}
// remoteHtlcsOption is an option which allows setting of htlcs on the remote
// commitment.
func remoteHtlcsOption(htlcs []HTLC) testChannelOption {
return func(params *testChannelParams) {
params.channel.RemoteCommitment.Htlcs = htlcs
}
}
// localShutdownOption is an option which sets the local upfront shutdown
// script for the channel.
func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
return func(params *testChannelParams) {
params.channel.LocalShutdownScript = addr
}
}
// remoteShutdownOption is an option which sets the remote upfront shutdown
// script for the channel.
func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
return func(params *testChannelParams) {
params.channel.RemoteShutdownScript = addr
}
}
// fundingPointOption is an option which sets the funding outpoint of the
// channel.
func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
return func(params *testChannelParams) {
params.channel.FundingOutpoint = chanPoint
}
}
// channelIDOption is an option which sets the short channel ID of the channel.
var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
return func(params *testChannelParams) {
params.channel.ShortChannelID = chanID
}
}
// createTestChannel writes a test channel to the database. It takes a set of
// functional options which can be used to overwrite the default of creating
// a pending channel that was broadcast at height 100.
func createTestChannel(t *testing.T, cdb *DB,
opts ...testChannelOption) *OpenChannel {
// Create a default set of parameters.
params := &testChannelParams{
channel: createTestChannelState(t, cdb),
addr: defaultAddr,
openChannel: false,
pendingHeight: uint32(defaultPendingHeight),
}
// Apply all functional options to the test channel params.
for _, o := range opts {
o(params)
}
// Mark the channel as pending.
err := params.channel.SyncPending(params.addr, params.pendingHeight)
if err != nil {
t.Fatalf("unable to save and serialize channel "+
"state: %v", err)
}
// If the parameters do not specify that we should open the channel
// fully, we return the pending channel.
if !params.openChannel {
return params.channel
}
// Mark the channel as open with the short channel id provided.
err = params.channel.MarkAsOpen(params.channel.ShortChannelID)
if err != nil {
t.Fatalf("unable to mark channel open: %v", err)
}
return params.channel
}
func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel {
// Simulate 1000 channel updates. // Simulate 1000 channel updates.
producer, err := shachain.NewRevocationProducerFromBytes(key[:]) producer, err := shachain.NewRevocationProducerFromBytes(key[:])
if err != nil { if err != nil {
return nil, err t.Fatalf("could not get producer: %v", err)
} }
store := shachain.NewRevocationStore() store := shachain.NewRevocationStore()
for i := 0; i < 1; i++ { for i := 0; i < 1; i++ {
preImage, err := producer.AtIndex(uint64(i)) preImage, err := producer.AtIndex(uint64(i))
if err != nil { if err != nil {
return nil, err t.Fatalf("could not get "+
"preimage: %v", err)
} }
if err := store.AddNextEntry(preImage); err != nil { if err := store.AddNextEntry(preImage); err != nil {
return nil, err t.Fatalf("could not add entry: %v", err)
} }
} }
@@ -228,7 +366,7 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) {
Db: cdb, Db: cdb,
Packager: NewChannelPackager(chanID), Packager: NewChannelPackager(chanID),
FundingTxn: testTx, FundingTxn: testTx,
}, nil }
} }
func TestOpenChannelPutGetDelete(t *testing.T) { func TestOpenChannelPutGetDelete(t *testing.T) {
@@ -240,15 +378,10 @@ func TestOpenChannelPutGetDelete(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// Create the test channel state, then add an additional fake HTLC // Create the test channel state, with additional htlcs on the local
// before syncing to disk. // and remote commitment.
state, err := createTestChannelState(cdb) localHtlcs := []HTLC{
if err != nil { {Signature: testSig.Serialize(),
t.Fatalf("unable to create channel state: %v", err)
}
state.LocalCommitment.Htlcs = []HTLC{
{
Signature: testSig.Serialize(),
Incoming: true, Incoming: true,
Amt: 10, Amt: 10,
RHash: key, RHash: key,
@@ -256,7 +389,8 @@ func TestOpenChannelPutGetDelete(t *testing.T) {
OnionBlob: []byte("onionblob"), OnionBlob: []byte("onionblob"),
}, },
} }
state.RemoteCommitment.Htlcs = []HTLC{
remoteHtlcs := []HTLC{
{ {
Signature: testSig.Serialize(), Signature: testSig.Serialize(),
Incoming: false, Incoming: false,
@@ -267,13 +401,11 @@ func TestOpenChannelPutGetDelete(t *testing.T) {
}, },
} }
addr := &net.TCPAddr{ state := createTestChannel(
IP: net.ParseIP("127.0.0.1"), t, cdb,
Port: 18556, remoteHtlcsOption(remoteHtlcs),
} localHtlcsOption(localHtlcs),
if err := state.SyncPending(addr, 101); err != nil { )
t.Fatalf("unable to save and serialize channel state: %v", err)
}
openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
if err != nil { if err != nil {
@@ -360,36 +492,28 @@ func TestOptionalShutdown(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
modifyChannel func(channel *OpenChannel) localShutdown lnwire.DeliveryAddress
expectedLocal lnwire.DeliveryAddress remoteShutdown lnwire.DeliveryAddress
expectedRemote lnwire.DeliveryAddress
}{ }{
{ {
name: "no shutdown scripts", name: "no shutdown scripts",
modifyChannel: func(channel *OpenChannel) {}, localShutdown: nil,
remoteShutdown: nil,
}, },
{ {
name: "local shutdown script", name: "local shutdown script",
modifyChannel: func(channel *OpenChannel) { localShutdown: local,
channel.LocalShutdownScript = local remoteShutdown: nil,
},
expectedLocal: local,
}, },
{ {
name: "remote shutdown script", name: "remote shutdown script",
modifyChannel: func(channel *OpenChannel) { localShutdown: nil,
channel.RemoteShutdownScript = remote remoteShutdown: remote,
},
expectedRemote: remote,
}, },
{ {
name: "both scripts set", name: "both scripts set",
modifyChannel: func(channel *OpenChannel) { localShutdown: local,
channel.LocalShutdownScript = local remoteShutdown: remote,
channel.RemoteShutdownScript = remote
},
expectedLocal: local,
expectedRemote: remote,
}, },
} }
@@ -403,40 +527,40 @@ func TestOptionalShutdown(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// Create the test channel state, then add an additional fake HTLC // Create a channel with upfront scripts set as
// before syncing to disk. // specified in the test.
state, err := createTestChannelState(cdb) state := createTestChannel(
t, cdb,
localShutdownOption(test.localShutdown),
remoteShutdownOption(test.remoteShutdown),
)
openChannels, err := cdb.FetchOpenChannels(
state.IdentityPub,
)
if err != nil { if err != nil {
t.Fatalf("unable to create channel state: %v", err) t.Fatalf("unable to fetch open"+
} " channel: %v", err)
test.modifyChannel(state)
// Write channels to Db.
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
if err := state.SyncPending(addr, 101); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
if err != nil {
t.Fatalf("unable to fetch open channel: %v", err)
} }
if len(openChannels) != 1 { if len(openChannels) != 1 {
t.Fatalf("Expected one channel open, got: %v", len(openChannels)) t.Fatalf("Expected one channel open,"+
" got: %v", len(openChannels))
} }
if !bytes.Equal(openChannels[0].LocalShutdownScript, test.expectedLocal) { if !bytes.Equal(openChannels[0].LocalShutdownScript,
t.Fatalf("Expected local: %x, got: %x", test.expectedLocal, test.localShutdown) {
t.Fatalf("Expected local: %x, got: %x",
test.localShutdown,
openChannels[0].LocalShutdownScript) openChannels[0].LocalShutdownScript)
} }
if !bytes.Equal(openChannels[0].RemoteShutdownScript, test.expectedRemote) { if !bytes.Equal(openChannels[0].RemoteShutdownScript,
t.Fatalf("Expected remote: %x, got: %x", test.expectedRemote, test.remoteShutdown) {
t.Fatalf("Expected remote: %x, got: %x",
test.remoteShutdown,
openChannels[0].RemoteShutdownScript) openChannels[0].RemoteShutdownScript)
} }
}) })
@@ -462,18 +586,7 @@ func TestChannelStateTransition(t *testing.T) {
// First create a minimal channel, then perform a full sync in order to // First create a minimal channel, then perform a full sync in order to
// persist the data. // persist the data.
channel, err := createTestChannelState(cdb) channel := createTestChannel(t, cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
if err := channel.SyncPending(addr, 101); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
// Add some HTLCs which were added during this new state transition. // Add some HTLCs which were added during this new state transition.
// Half of the HTLCs are incoming, while the other half are outgoing. // Half of the HTLCs are incoming, while the other half are outgoing.
@@ -776,21 +889,9 @@ func TestFetchPendingChannels(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// Create first test channel state // Create a pending channel that was broadcast at height 99.
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
const broadcastHeight = 99 const broadcastHeight = 99
if err := state.SyncPending(addr, broadcastHeight); err != nil { createTestChannel(t, cdb, pendingHeightOption(broadcastHeight))
t.Fatalf("unable to save and serialize channel state: %v", err)
}
pendingChannels, err := cdb.FetchPendingChannels() pendingChannels, err := cdb.FetchPendingChannels()
if err != nil { if err != nil {
@@ -867,35 +968,8 @@ func TestFetchClosedChannels(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// First create a test channel, that we'll be closing within this pull // Create an open channel in the database.
// request. state := createTestChannel(t, cdb, openChannelOption())
state, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
// Next sync the channel to disk, marking it as being in a pending open
// state.
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
const broadcastHeight = 99
if err := state.SyncPending(addr, broadcastHeight); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
// Next, simulate the confirmation of the channel by marking it as
// pending within the database.
chanOpenLoc := lnwire.ShortChannelID{
BlockHeight: 5,
TxIndex: 10,
TxPosition: 15,
}
err = state.MarkAsOpen(chanOpenLoc)
if err != nil {
t.Fatalf("unable to mark channel as open: %v", err)
}
// Next, close the channel by including a close channel summary in the // Next, close the channel by including a close channel summary in the
// database. // database.
@@ -975,7 +1049,6 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
const numChannels = 2 const numChannels = 2
const broadcastHeight = 99 const broadcastHeight = 99
addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555}
// We'll start by creating two channels within our test database. One of // We'll start by creating two channels within our test database. One of
// them will have their funding transaction confirmed on-chain, while // them will have their funding transaction confirmed on-chain, while
@@ -988,15 +1061,11 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
channels := make([]*OpenChannel, numChannels) channels := make([]*OpenChannel, numChannels)
for i := 0; i < numChannels; i++ { for i := 0; i < numChannels; i++ {
channel, err := createTestChannelState(db) // Create a pending channel in the database at the broadcast
if err != nil { // height.
t.Fatalf("unable to create channel: %v", err) channels[i] = createTestChannel(
} t, db, pendingHeightOption(broadcastHeight),
err = channel.SyncPending(addr, broadcastHeight) )
if err != nil {
t.Fatalf("unable to sync channel: %v", err)
}
channels[i] = channel
} }
// We'll only confirm the first one. // We'll only confirm the first one.
@@ -1106,21 +1175,7 @@ func TestRefreshShortChanID(t *testing.T) {
defer cleanUp() defer cleanUp()
// First create a test channel. // First create a test channel.
state, err := createTestChannelState(cdb) state := createTestChannel(t, cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
// Mark the channel as pending within the channeldb.
const broadcastHeight = 99
if err := state.SyncPending(addr, broadcastHeight); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
// Next, locate the pending channel with the database. // Next, locate the pending channel with the database.
pendingChannels, err := cdb.FetchPendingChannels() pendingChannels, err := cdb.FetchPendingChannels()

View File

@@ -556,42 +556,28 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
// within the database, including pending open, fully open and channels waiting // within the database, including pending open, fully open and channels waiting
// for a closing transaction to confirm. // for a closing transaction to confirm.
func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { func (d *DB) FetchAllChannels() ([]*OpenChannel, error) {
var channels []*OpenChannel return fetchChannels(d)
// TODO(halseth): fetch all in one db tx.
openChannels, err := d.FetchAllOpenChannels()
if err != nil {
return nil, err
}
channels = append(channels, openChannels...)
pendingChannels, err := d.FetchPendingChannels()
if err != nil {
return nil, err
}
channels = append(channels, pendingChannels...)
waitingClose, err := d.FetchWaitingCloseChannels()
if err != nil {
return nil, err
}
channels = append(channels, waitingClose...)
return channels, nil
} }
// FetchAllOpenChannels will return all channels that have the funding // FetchAllOpenChannels will return all channels that have the funding
// transaction confirmed, and is not waiting for a closing transaction to be // transaction confirmed, and is not waiting for a closing transaction to be
// confirmed. // confirmed.
func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
return fetchChannels(d, false, false) return fetchChannels(
d,
pendingChannelFilter(false),
waitingCloseFilter(false),
)
} }
// FetchPendingChannels will return channels that have completed the process of // FetchPendingChannels will return channels that have completed the process of
// generating and broadcasting funding transactions, but whose funding // generating and broadcasting funding transactions, but whose funding
// transactions have yet to be confirmed on the blockchain. // transactions have yet to be confirmed on the blockchain.
func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
return fetchChannels(d, true, false) return fetchChannels(d,
pendingChannelFilter(true),
waitingCloseFilter(false),
)
} }
// FetchWaitingCloseChannels will return all channels that have been opened, // FetchWaitingCloseChannels will return all channels that have been opened,
@@ -599,25 +585,49 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
// //
// NOTE: This includes channels that are also pending to be opened. // NOTE: This includes channels that are also pending to be opened.
func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
waitingClose, err := fetchChannels(d, false, true) return fetchChannels(
if err != nil { d, waitingCloseFilter(true),
return nil, err )
} }
pendingWaitingClose, err := fetchChannels(d, true, true)
if err != nil {
return nil, err
}
return append(waitingClose, pendingWaitingClose...), nil // fetchChannelsFilter applies a filter to channels retrieved in fetchchannels.
// A set of filters can be combined to filter across multiple dimensions.
type fetchChannelsFilter func(channel *OpenChannel) bool
// pendingChannelFilter returns a filter based on whether channels are pending
// (ie, their funding transaction still needs to confirm). If pending is false,
// channels with confirmed funding transactions are returned.
func pendingChannelFilter(pending bool) fetchChannelsFilter {
return func(channel *OpenChannel) bool {
return channel.IsPending == pending
}
}
// waitingCloseFilter returns a filter which filters channels based on whether
// they are awaiting the confirmation of their closing transaction. If waiting
// close is true, channels that have had their closing tx broadcast are
// included. If it is false, channels that are not awaiting confirmation of
// their close transaction are returned.
func waitingCloseFilter(waitingClose bool) fetchChannelsFilter {
return func(channel *OpenChannel) bool {
// If the channel is in any other state than Default,
// then it means it is waiting to be closed.
channelWaitingClose :=
channel.ChanStatus() != ChanStatusDefault
// Include the channel if it matches the value for
// waiting close that we are filtering on.
return channelWaitingClose == waitingClose
}
} }
// fetchChannels attempts to retrieve channels currently stored in the // fetchChannels attempts to retrieve channels currently stored in the
// database. The pending parameter determines whether only pending channels // database. It takes a set of filters which are applied to each channel to
// will be returned, or only open channels will be returned. The waitingClose // obtain a set of channels with the desired set of properties. Only channels
// parameter determines whether only channels waiting for a closing transaction // which have a true value returned for *all* of the filters will be returned.
// to be confirmed should be returned. If no active channels exist within the // If no filters are provided, every channel in the open channels bucket will
// network, then ErrNoActiveChannels is returned. // be returned.
func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) {
var channels []*OpenChannel var channels []*OpenChannel
err := d.View(func(tx *bbolt.Tx) error { err := d.View(func(tx *bbolt.Tx) error {
@@ -667,24 +677,27 @@ func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) {
"node_key=%x: %v", chainHash[:], k, err) "node_key=%x: %v", chainHash[:], k, err)
} }
for _, channel := range nodeChans { for _, channel := range nodeChans {
if channel.IsPending != pending { // includeChannel indicates whether the channel
continue // meets the criteria specified by our filters.
includeChannel := true
// Run through each filter and check whether the
// channel should be included.
for _, f := range filters {
// If the channel fails the filter, set
// includeChannel to false and don't bother
// checking the remaining filters.
if !f(channel) {
includeChannel = false
break
}
} }
// If the channel is in any other state // If the channel passed every filter, include it in
// than Default, then it means it is // our set of channels.
// waiting to be closed. if includeChannel {
channelWaitingClose := channels = append(channels, channel)
channel.ChanStatus() != ChanStatusDefault
// Only include it if we requested
// channels with the same waitingClose
// status.
if channelWaitingClose != waitingClose {
continue
} }
channels = append(channels, channel)
} }
return nil return nil
}) })

View File

@@ -100,10 +100,7 @@ func TestFetchClosedChannelForID(t *testing.T) {
// Create the test channel state, that we will mutate the index of the // Create the test channel state, that we will mutate the index of the
// funding point. // funding point.
state, err := createTestChannelState(cdb) state := createTestChannelState(t, cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
// Now run through the number of channels, and modify the outpoint index // Now run through the number of channels, and modify the outpoint index
// to create new channel IDs. // to create new channel IDs.
@@ -111,14 +108,12 @@ func TestFetchClosedChannelForID(t *testing.T) {
// Save the open channel to disk. // Save the open channel to disk.
state.FundingOutpoint.Index = i state.FundingOutpoint.Index = i
addr := &net.TCPAddr{ // Write the channel to disk in a pending state.
IP: net.ParseIP("127.0.0.1"), createTestChannel(
Port: 18556, t, cdb,
} fundingPointOption(state.FundingOutpoint),
if err := state.SyncPending(addr, 101); err != nil { openChannelOption(),
t.Fatalf("unable to save and serialize channel "+ )
"state: %v", err)
}
// Close the channel. To make sure we retrieve the correct // Close the channel. To make sure we retrieve the correct
// summary later, we make them differ in the SettledBalance. // summary later, we make them differ in the SettledBalance.
@@ -235,26 +230,8 @@ func TestFetchChannel(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// Create the test channel state that we'll sync to the database // Create an open channel.
// shortly. channelState := createTestChannel(t, cdb, openChannelOption())
channelState, err := createTestChannelState(cdb)
if err != nil {
t.Fatalf("unable to create channel state: %v", err)
}
// Mark the channel as pending, then immediately mark it as open to it
// can be fully visible.
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
if err := channelState.SyncPending(addr, 9); err != nil {
t.Fatalf("unable to save and serialize channel state: %v", err)
}
err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99))
if err != nil {
t.Fatalf("unable to mark channel open: %v", err)
}
// Next, attempt to fetch the channel by its chan point. // Next, attempt to fetch the channel by its chan point.
dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint)
@@ -271,7 +248,7 @@ func TestFetchChannel(t *testing.T) {
// If we attempt to query for a non-exist ante channel, then we should // If we attempt to query for a non-exist ante channel, then we should
// get an error. // get an error.
channelState2, err := createTestChannelState(cdb) channelState2 := createTestChannelState(t, cdb)
if err != nil { if err != nil {
t.Fatalf("unable to create channel state: %v", err) t.Fatalf("unable to create channel state: %v", err)
} }
@@ -491,19 +468,9 @@ func TestAbandonChannel(t *testing.T) {
t.Fatalf("removing non-existent channel should have failed") t.Fatalf("removing non-existent channel should have failed")
} }
// We'll now create a new channel to abandon shortly. // We'll now create a new channel in a pending state to abandon
chanState, err := createTestChannelState(cdb) // shortly.
if err != nil { chanState := createTestChannel(t, cdb)
t.Fatalf("unable to create channel state: %v", err)
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
err = chanState.SyncPending(addr, 10)
if err != nil {
t.Fatalf("unable to sync pending channel: %v", err)
}
// We should now be able to abandon the channel without any errors. // We should now be able to abandon the channel without any errors.
closeHeight := uint32(11) closeHeight := uint32(11)
@@ -533,3 +500,196 @@ func TestAbandonChannel(t *testing.T) {
t.Fatalf("unable to abandon channel: %v", err) t.Fatalf("unable to abandon channel: %v", err)
} }
} }
// TestFetchChannels tests the filtering of open channels in fetchChannels.
// It tests the case where no filters are provided (which is equivalent to
// FetchAllOpenChannels) and every combination of pending and waiting close.
func TestFetchChannels(t *testing.T) {
// Create static channel IDs for each kind of channel retrieved by
// fetchChannels so that the expected channel IDs can be set in tests.
var (
// Pending is a channel that is pending open, and has not had
// a close initiated.
pendingChan = lnwire.NewShortChanIDFromInt(1)
// pendingWaitingClose is a channel that is pending open and
// has has its closing transaction broadcast.
pendingWaitingChan = lnwire.NewShortChanIDFromInt(2)
// openChan is a channel that has confirmed on chain.
openChan = lnwire.NewShortChanIDFromInt(3)
// openWaitingChan is a channel that has confirmed on chain,
// and it waiting for its close transaction to confirm.
openWaitingChan = lnwire.NewShortChanIDFromInt(4)
)
tests := []struct {
name string
filters []fetchChannelsFilter
expectedChannels map[lnwire.ShortChannelID]bool
}{
{
name: "get all channels",
filters: []fetchChannelsFilter{},
expectedChannels: map[lnwire.ShortChannelID]bool{
pendingChan: true,
pendingWaitingChan: true,
openChan: true,
openWaitingChan: true,
},
},
{
name: "pending channels",
filters: []fetchChannelsFilter{
pendingChannelFilter(true),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
pendingChan: true,
pendingWaitingChan: true,
},
},
{
name: "open channels",
filters: []fetchChannelsFilter{
pendingChannelFilter(false),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
openChan: true,
openWaitingChan: true,
},
},
{
name: "waiting close channels",
filters: []fetchChannelsFilter{
waitingCloseFilter(true),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
pendingWaitingChan: true,
openWaitingChan: true,
},
},
{
name: "not waiting close channels",
filters: []fetchChannelsFilter{
waitingCloseFilter(false),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
pendingChan: true,
openChan: true,
},
},
{
name: "pending waiting",
filters: []fetchChannelsFilter{
pendingChannelFilter(true),
waitingCloseFilter(true),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
pendingWaitingChan: true,
},
},
{
name: "pending, not waiting",
filters: []fetchChannelsFilter{
pendingChannelFilter(true),
waitingCloseFilter(false),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
pendingChan: true,
},
},
{
name: "open waiting",
filters: []fetchChannelsFilter{
pendingChannelFilter(false),
waitingCloseFilter(true),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
openWaitingChan: true,
},
},
{
name: "open, not waiting",
filters: []fetchChannelsFilter{
pendingChannelFilter(false),
waitingCloseFilter(false),
},
expectedChannels: map[lnwire.ShortChannelID]bool{
openChan: true,
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test "+
"database: %v", err)
}
defer cleanUp()
// Create a pending channel that is not awaiting close.
createTestChannel(
t, cdb, channelIDOption(pendingChan),
)
// Create a pending channel which has has been marked as
// broadcast, indicating that its closing transaction is
// waiting to confirm.
pendingClosing := createTestChannel(
t, cdb,
channelIDOption(pendingWaitingChan),
)
err = pendingClosing.MarkCoopBroadcasted(nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Create a open channel that is not awaiting close.
createTestChannel(
t, cdb,
channelIDOption(openChan),
openChannelOption(),
)
// Create a open channel which has has been marked as
// broadcast, indicating that its closing transaction is
// waiting to confirm.
openClosing := createTestChannel(
t, cdb,
channelIDOption(openWaitingChan),
openChannelOption(),
)
err = openClosing.MarkCoopBroadcasted(nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
channels, err := fetchChannels(cdb, test.filters...)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(channels) != len(test.expectedChannels) {
t.Fatalf("expected: %v channels, "+
"got: %v", len(test.expectedChannels),
len(channels))
}
for _, ch := range channels {
_, ok := test.expectedChannels[ch.ShortChannelID]
if !ok {
t.Fatalf("fetch channels unexpected "+
"channel: %v", ch.ShortChannelID)
}
}
})
}
}