channeldb: add a FetchChannelByID method

Add a FetchChannelByID method that allows a caller to fetch an
OpenChannel using an lnwire.ChannelID.
This commit is contained in:
Elle Mouton 2023-03-22 09:35:59 +02:00
parent 63442cbe51
commit fe2304efad
No known key found for this signature in database
GPG Key ID: D7D916376026F177
2 changed files with 63 additions and 8 deletions

View File

@ -674,6 +674,54 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
return c.channelScanner(tx, selector)
}
// FetchChannelByID attempts to locate a channel specified by the passed channel
// ID. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) (
*OpenChannel, error) {
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
var (
targetChanPointBytes []byte
targetChanPoint *wire.OutPoint
// errChanFound is used to signal that the channel has
// been found so that iteration through the DB buckets
// can stop.
errChanFound = errors.New("channel found")
)
err := chainBkt.ForEach(func(k, _ []byte) error {
var outPoint wire.OutPoint
err := readOutpoint(bytes.NewReader(k), &outPoint)
if err != nil {
return err
}
chanID := lnwire.NewChanIDFromOutPoint(&outPoint)
if chanID != id {
return nil
}
targetChanPoint = &outPoint
targetChanPointBytes = k
return errChanFound
})
if err != nil && !errors.Is(err, errChanFound) {
return nil, nil, err
}
if targetChanPoint == nil {
return nil, nil, ErrChannelNotFound
}
return targetChanPointBytes, targetChanPoint, nil
}
return c.channelScanner(tx, selector)
}
// channelSelector describes a function that takes a chain-hash bucket from
// within the open-channel DB and returns the wanted channel point bytes, and
// channel point. It must return the ErrChannelNotFound error if the wanted

View File

@ -12,7 +12,6 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
@ -238,10 +237,16 @@ func TestFetchChannel(t *testing.T) {
// The decoded channel state should be identical to what we stored
// above.
if !reflect.DeepEqual(channelState, dbChannel) {
t.Fatalf("channel state doesn't match:: %v vs %v",
spew.Sdump(channelState), spew.Sdump(dbChannel))
}
require.Equal(t, channelState, dbChannel)
// Next, attempt to fetch the channel by its channel ID.
chanID := lnwire.NewChanIDFromOutPoint(&channelState.FundingOutpoint)
dbChannel, err = cdb.FetchChannelByID(nil, chanID)
require.NoError(t, err, "unable to fetch channel")
// The decoded channel state should be identical to what we stored
// above.
require.Equal(t, channelState, dbChannel)
// If we attempt to query for a non-existent channel, then we should
// get an error.
@ -252,9 +257,11 @@ func TestFetchChannel(t *testing.T) {
channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load()
_, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint)
if err == nil {
t.Fatalf("expected query to fail")
}
require.ErrorIs(t, err, ErrChannelNotFound)
chanID2 := lnwire.NewChanIDFromOutPoint(&channelState2.FundingOutpoint)
_, err = cdb.FetchChannelByID(nil, chanID2)
require.ErrorIs(t, err, ErrChannelNotFound)
}
func genRandomChannelShell() (*ChannelShell, error) {