diff --git a/channeldb/db.go b/channeldb/db.go index 26268eabe..146c07449 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -674,6 +674,54 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( return c.channelScanner(tx, selector) } +// FetchChannelByID attempts to locate a channel specified by the passed channel +// ID. If the channel cannot be found, then an error will be returned. +// Optionally an existing db tx can be supplied. +func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) ( + *OpenChannel, error) { + + selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, + error) { + + var ( + targetChanPointBytes []byte + targetChanPoint *wire.OutPoint + + // errChanFound is used to signal that the channel has + // been found so that iteration through the DB buckets + // can stop. + errChanFound = errors.New("channel found") + ) + err := chainBkt.ForEach(func(k, _ []byte) error { + var outPoint wire.OutPoint + err := readOutpoint(bytes.NewReader(k), &outPoint) + if err != nil { + return err + } + + chanID := lnwire.NewChanIDFromOutPoint(&outPoint) + if chanID != id { + return nil + } + + targetChanPoint = &outPoint + targetChanPointBytes = k + + return errChanFound + }) + if err != nil && !errors.Is(err, errChanFound) { + return nil, nil, err + } + if targetChanPoint == nil { + return nil, nil, ErrChannelNotFound + } + + return targetChanPointBytes, targetChanPoint, nil + } + + return c.channelScanner(tx, selector) +} + // channelSelector describes a function that takes a chain-hash bucket from // within the open-channel DB and returns the wanted channel point bytes, and // channel point. It must return the ErrChannelNotFound error if the wanted diff --git a/channeldb/db_test.go b/channeldb/db_test.go index bed9c6e32..62e720407 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -12,7 +12,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -238,10 +237,16 @@ func TestFetchChannel(t *testing.T) { // The decoded channel state should be identical to what we stored // above. - if !reflect.DeepEqual(channelState, dbChannel) { - t.Fatalf("channel state doesn't match:: %v vs %v", - spew.Sdump(channelState), spew.Sdump(dbChannel)) - } + require.Equal(t, channelState, dbChannel) + + // Next, attempt to fetch the channel by its channel ID. + chanID := lnwire.NewChanIDFromOutPoint(&channelState.FundingOutpoint) + dbChannel, err = cdb.FetchChannelByID(nil, chanID) + require.NoError(t, err, "unable to fetch channel") + + // The decoded channel state should be identical to what we stored + // above. + require.Equal(t, channelState, dbChannel) // If we attempt to query for a non-existent channel, then we should // get an error. @@ -252,9 +257,11 @@ func TestFetchChannel(t *testing.T) { channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load() _, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint) - if err == nil { - t.Fatalf("expected query to fail") - } + require.ErrorIs(t, err, ErrChannelNotFound) + + chanID2 := lnwire.NewChanIDFromOutPoint(&channelState2.FundingOutpoint) + _, err = cdb.FetchChannelByID(nil, chanID2) + require.ErrorIs(t, err, ErrChannelNotFound) } func genRandomChannelShell() (*ChannelShell, error) {