invoicesrpc: remove invoicerpc server's access to ChannelGraph pointer

Define a new GraphSource interface required by the invoicerpc server and
remove its access to the graphdb.ChannelGraph pointer. Add the new
invoicesrpc.GraphSource interface to the GraphSource interface
and let DBSource implement it.
This commit is contained in:
Elle Mouton 2024-11-11 16:43:38 +02:00
parent 9854bad720
commit 6f3d45f5d9
No known key found for this signature in database
GPG Key ID: D7D916376026F177
11 changed files with 124 additions and 60 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/graph/session"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
@ -82,6 +83,28 @@ func (s *DBSource) FetchNodeFeatures(_ context.Context, tx session.RTx,
return s.db.FetchNodeFeatures(kvdbTx, node)
}
// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// graphdb.ErrEdgeNotFound is returned.
//
// NOTE: this is part of the invoicesrpc.GraphSource interface.
func (s *DBSource) FetchChannelEdgesByID(_ context.Context,
chanID uint64) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
return s.db.FetchChannelEdgesByID(chanID)
}
// IsPublicNode determines whether the node with the given public key is seen as
// a public node in the graph from the graph's source node's point of view.
//
// NOTE: this is part of the invoicesrpc.GraphSource interface.
func (s *DBSource) IsPublicNode(_ context.Context,
pubKey [33]byte) (bool, error) {
return s.db.IsPublicNode(pubKey)
}
// kvdbRTx is an implementation of graphdb.RTx backed by a KVDB database read
// transaction.
type kvdbRTx struct {

View File

@ -1,9 +1,13 @@
package sources
import "github.com/lightningnetwork/lnd/graph/session"
import (
"github.com/lightningnetwork/lnd/graph/session"
"github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
)
// GraphSource defines the read-only graph interface required by LND for graph
// related queries.
type GraphSource interface {
session.ReadOnlyGraph
invoicesrpc.GraphSource
}

View File

@ -18,7 +18,6 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lntypes"
@ -75,8 +74,9 @@ type AddInvoiceConfig struct {
// channel graph.
ChanDB *channeldb.ChannelStateDB
// Graph holds a reference to the ChannelGraph database.
Graph *graphdb.ChannelGraph
// Graph holds a reference to a GraphSource that can be queried for
// graph related data.
Graph GraphSource
// GenInvoiceFeatures returns a feature containing feature bits that
// should be advertised on freshly generated invoices.
@ -627,6 +627,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) (
*models.ChannelEdgePolicy, bool) {
ctx := context.TODO()
// Since we're only interested in our private channels, we'll skip
// public ones.
if channel.IsPublic {
@ -648,7 +650,7 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) (
// channels.
var remotePub [33]byte
copy(remotePub[:], channel.RemotePubkey.SerializeCompressed())
isRemoteNodePublic, err := cfg.IsPublicNode(remotePub)
isRemoteNodePublic, err := cfg.IsPublicNode(ctx, remotePub)
if err != nil {
log.Errorf("Unable to determine if node %x "+
"is advertised: %v", remotePub, err)
@ -663,13 +665,17 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) (
}
// Fetch the policies for each end of the channel.
info, p1, p2, err := cfg.FetchChannelEdgesByID(channel.ShortChannelID)
info, p1, p2, err := cfg.FetchChannelEdgesByID(
ctx, channel.ShortChannelID,
)
if err != nil {
// In the case of zero-conf channels, it may be the case that
// the alias SCID was deleted from the graph, and replaced by
// the confirmed SCID. Check the Graph for the confirmed SCID.
confirmedScid := channel.ConfirmedScidZC
info, p1, p2, err = cfg.FetchChannelEdgesByID(confirmedScid)
info, p1, p2, err = cfg.FetchChannelEdgesByID(
ctx, confirmedScid,
)
if err != nil {
log.Errorf("Unable to fetch the routing policies for "+
"the edges of the channel %v: %v",
@ -759,13 +765,13 @@ type SelectHopHintsCfg struct {
// IsPublicNode is returns a bool indicating whether the node with the
// given public key is seen as a public node in the graph from the
// graph's source node's point of view.
IsPublicNode func(pubKey [33]byte) (bool, error)
IsPublicNode func(ctx context.Context, pubKey [33]byte) (bool, error)
// FetchChannelEdgesByID attempts to lookup the two directed edges for
// the channel identified by the channel ID.
FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy,
error)
FetchChannelEdgesByID func(ctx context.Context,
chanID uint64) (*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error)
// GetAlias allows the peer's alias SCID to be retrieved for private
// option_scid_alias channels.

View File

@ -1,6 +1,7 @@
package invoicesrpc
import (
"context"
"encoding/hex"
"fmt"
"testing"
@ -35,8 +36,10 @@ func newHopHintsConfigMock(t *testing.T) *hopHintsConfigMock {
}
// IsPublicNode mocks node public state lookup.
func (h *hopHintsConfigMock) IsPublicNode(pubKey [33]byte) (bool, error) {
args := h.Mock.Called(pubKey)
func (h *hopHintsConfigMock) IsPublicNode(ctx context.Context,
pubKey [33]byte) (bool, error) {
args := h.Mock.Called(ctx, pubKey)
return args.Bool(0), args.Error(1)
}
@ -66,11 +69,11 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel,
// FetchChannelEdgesByID attempts to lookup the two directed edges for
// the channel identified by the channel ID.
func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
func (h *hopHintsConfigMock) FetchChannelEdgesByID(ctx context.Context,
chanID uint64) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
args := h.Mock.Called(chanID)
args := h.Mock.Called(ctx, chanID)
// If our error is non-nil, we expect nil responses otherwise. Our
// casts below will fail with nil values, so we check our error and
@ -161,7 +164,7 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(false, nil)
},
channel: &channeldb.OpenChannel{
@ -185,18 +188,18 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(nil, nil, nil, fmt.Errorf("no edge"))
// TODO(positiveblue): check that the func is called with the
// right scid when we have access to the `confirmedscid` form
// here.
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(nil, nil, nil, fmt.Errorf("no edge"))
},
channel: &channeldb.OpenChannel{
@ -220,11 +223,11 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -258,11 +261,11 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -296,14 +299,14 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
var selectedPolicy [33]byte
copy(selectedPolicy[:], getTestPubKey().SerializeCompressed())
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{
NodeKey1Bytes: selectedPolicy,
@ -347,11 +350,11 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -392,11 +395,11 @@ var shouldIncludeChannelTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -559,11 +562,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -609,11 +612,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -660,11 +663,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -693,11 +696,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -710,11 +713,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -747,11 +750,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -764,11 +767,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},
@ -802,11 +805,11 @@ var populateHopHintsTestCases = []struct {
).Once().Return(true)
h.Mock.On(
"IsPublicNode", mock.Anything,
"IsPublicNode", mock.Anything, mock.Anything,
).Once().Return(true, nil)
h.Mock.On(
"FetchChannelEdgesByID", mock.Anything,
"FetchChannelEdgesByID", mock.Anything, mock.Anything,
).Once().Return(
&models.ChannelEdgeInfo{},
&models.ChannelEdgePolicy{},

View File

@ -6,7 +6,6 @@ package invoicesrpc
import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/channeldb"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/macaroons"
@ -52,9 +51,8 @@ type Config struct {
// specified.
DefaultCLTVExpiry uint32
// GraphDB is a global database instance which is needed to access the
// channel graph.
GraphDB *graphdb.ChannelGraph
// Graph can be used for graph related queries.
Graph GraphSource
// ChanStateDB is a possibly replicated db instance which contains the
// channel state

View File

@ -0,0 +1,22 @@
package invoicesrpc
import (
"context"
"github.com/lightningnetwork/lnd/graph/db/models"
)
// GraphSource defines the graph interface required by the invoice rpc server.
type GraphSource interface {
// FetchChannelEdgesByID attempts to look up the two directed edges for
// the channel identified by the channel ID. If the channel can't be
// found, then graphdb.ErrEdgeNotFound is returned.
FetchChannelEdgesByID(ctx context.Context, chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error)
// IsPublicNode is a helper method that determines whether the node with
// the given public key is seen as a public node in the graph from the
// graph's source node's point of view.
IsPublicNode(ctx context.Context, pubKey [33]byte) (bool, error)
}

View File

@ -346,7 +346,7 @@ func (s *Server) AddHoldInvoice(ctx context.Context,
NodeSigner: s.cfg.NodeSigner,
DefaultCLTVExpiry: s.cfg.DefaultCLTVExpiry,
ChanDB: s.cfg.ChanStateDB,
Graph: s.cfg.GraphDB,
Graph: s.cfg.Graph,
GenInvoiceFeatures: s.cfg.GenInvoiceFeatures,
GenAmpInvoiceFeatures: s.cfg.GenAmpInvoiceFeatures,
GetAlias: s.cfg.GetAlias,

View File

@ -2,6 +2,7 @@ package blindedpath
import (
"bytes"
"context"
"errors"
"fmt"
"math"
@ -42,7 +43,8 @@ type BuildBlindedPathCfg struct {
// FetchChannelEdgesByID attempts to look up the two directed edges for
// the channel identified by the channel ID.
FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo,
FetchChannelEdgesByID func(ctx context.Context,
chanID uint64) (*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error)
// FetchOurOpenChannels fetches this node's set of open channels.
@ -643,7 +645,9 @@ func getNodeChannelPolicy(cfg *BuildBlindedPathCfg, chanID uint64,
// Attempt to fetch channel updates for the given channel. We will have
// at most two updates for a given channel.
_, update1, update2, err := cfg.FetchChannelEdgesByID(chanID)
_, update1, update2, err := cfg.FetchChannelEdgesByID(
context.TODO(), chanID,
)
if err != nil {
return nil, err
}

View File

@ -2,6 +2,7 @@ package blindedpath
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"math/rand"
@ -597,7 +598,7 @@ func TestBuildBlindedPath(t *testing.T) {
return []*route.Route{realRoute}, nil
},
FetchChannelEdgesByID: func(chanID uint64) (
FetchChannelEdgesByID: func(_ context.Context, chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
@ -765,7 +766,7 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
return []*route.Route{realRoute}, nil
},
FetchChannelEdgesByID: func(chanID uint64) (
FetchChannelEdgesByID: func(_ context.Context, chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
@ -936,7 +937,7 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
return []*route.Route{realRoute, realRoute, realRoute},
nil
},
FetchChannelEdgesByID: func(chanID uint64) (
FetchChannelEdgesByID: func(_ context.Context, chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {

View File

@ -796,7 +796,8 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service,
err = subServerCgs.PopulateDependencies(
r.cfg, s.cc, r.cfg.networkDir, macService, atpl, invoiceRegistry,
s.htlcSwitch, r.cfg.ActiveNetParams.Params, s.chanRouter,
routerBackend, s.nodeSigner, s.graphDB, s.chanStateDB,
routerBackend, s.nodeSigner, s.graphDB, s.graphSource,
s.chanStateDB,
s.sweeper, tower, s.towerClientMgr, r.cfg.net.ResolveTCPAddr,
genInvoiceFeatures, genAmpInvoiceFeatures,
s.getNodeAnnouncement, s.updateAndBroadcastSelfNode, parseAddr,
@ -6096,7 +6097,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context,
NodeSigner: r.server.nodeSigner,
DefaultCLTVExpiry: defaultDelta,
ChanDB: r.server.chanStateDB,
Graph: r.server.graphDB,
Graph: r.server.graphSource,
GenInvoiceFeatures: func() *lnwire.FeatureVector {
v := r.server.featureMgr.Get(feature.SetInvoice)

View File

@ -13,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
graphdb "github.com/lightningnetwork/lnd/graph/db"
graphsources "github.com/lightningnetwork/lnd/graph/sources"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lncfg"
@ -114,6 +115,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config,
routerBackend *routerrpc.RouterBackend,
nodeSigner *netann.NodeSigner,
graphDB *graphdb.ChannelGraph,
graphSource graphsources.GraphSource,
chanStateDB *channeldb.ChannelStateDB,
sweeper *sweep.UtxoSweeper,
tower *watchtower.Standalone,
@ -262,8 +264,8 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config,
subCfgValue.FieldByName("DefaultCLTVExpiry").Set(
reflect.ValueOf(defaultDelta),
)
subCfgValue.FieldByName("GraphDB").Set(
reflect.ValueOf(graphDB),
subCfgValue.FieldByName("Graph").Set(
reflect.ValueOf(graphSource),
)
subCfgValue.FieldByName("ChanStateDB").Set(
reflect.ValueOf(chanStateDB),