From f89e3ceced158a8f2334d337a66099f5ecbb046e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 11 Jun 2025 15:40:10 +0200 Subject: [PATCH] graph/db+sqldb: implement ForEachSourceNodeChannel In this commit, the ForEachSourceNodeChannel implementation of the SQLStore is added. Since this is the first method of the SQLStore that fetches channel and policy info, it also adds all the helpers that are required to do so. These will be re-used in upcoming commits as more "For"-type methods are added. With this implementation, we convert the `TestForEachSourceNodeChannel` such that it is run against SQL backends. --- docs/release-notes/release-notes-0.20.0.md | 1 + graph/db/graph_test.go | 2 +- graph/db/sql_store.go | 444 +++++++++++++++++++++ sqldb/sqlc/graph.sql.go | 266 ++++++++++++ sqldb/sqlc/querier.go | 3 + sqldb/sqlc/queries/graph.sql | 82 ++++ 6 files changed, 797 insertions(+), 1 deletion(-) diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index b25b9e6d8..d5edcab17 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -76,6 +76,7 @@ circuit. The indices are only available for forwarding events saved after v0.20. * [1](https://github.com/lightningnetwork/lnd/pull/9866) * [2](https://github.com/lightningnetwork/lnd/pull/9869) * [3](https://github.com/lightningnetwork/lnd/pull/9887) + * [4](https://github.com/lightningnetwork/lnd/pull/9931) ## RPC Updates diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index b279f5476..8dab86611 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1183,7 +1183,7 @@ func TestAddEdgeProof(t *testing.T) { func TestForEachSourceNodeChannel(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // Create a source node (A) and set it as such in the DB. nodeA := createTestVertex(t) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 9bdad6e36..137c1d2e2 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -14,8 +14,11 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -77,7 +80,9 @@ type SQLQueries interface { CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) + GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) + ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error @@ -88,6 +93,7 @@ type SQLQueries interface { UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error) InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error + GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error } @@ -651,6 +657,136 @@ func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy, } } +// ForEachSourceNodeChannel iterates through all channels of the source node, +// executing the passed callback on each. The call-back is provided with the +// channel's outpoint, whether we have a policy for the channel and the channel +// peer's node information. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error) error { + + var ctx = context.TODO() + + return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1) + if err != nil { + return fmt.Errorf("unable to fetch source node: %w", + err) + } + + return forEachNodeChannel( + ctx, db, s.cfg.ChainHash, nodeID, + func(info *models.ChannelEdgeInfo, + outPolicy *models.ChannelEdgePolicy, + _ *models.ChannelEdgePolicy) error { + + // Fetch the other node. + var ( + otherNodePub [33]byte + node1 = info.NodeKey1Bytes + node2 = info.NodeKey2Bytes + ) + switch { + case bytes.Equal(node1[:], nodePub[:]): + otherNodePub = node2 + case bytes.Equal(node2[:], nodePub[:]): + otherNodePub = node1 + default: + return fmt.Errorf("node not " + + "participating in this channel") + } + + _, otherNode, err := getNodeByPubKey( + ctx, db, otherNodePub, + ) + if err != nil { + return fmt.Errorf("unable to fetch "+ + "other node(%x): %w", + otherNodePub, err) + } + + return cb( + info.ChannelPoint, outPolicy != nil, + otherNode, + ) + }, + ) + }, sqldb.NoOpReset) +} + +// forEachNodeChannel iterates through all channels of a node, executing +// the passed callback on each. The call-back is provided with the channel's +// edge information, the outgoing policy and the incoming policy for the +// channel and node combo. +func forEachNodeChannel(ctx context.Context, db SQLQueries, + chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + // Get all the V1 channels for this node.Add commentMore actions + rows, err := db.ListChannelsByNodeID( + ctx, sqlc.ListChannelsByNodeIDParams{ + Version: int16(ProtocolV1), + NodeID1: id, + }, + ) + if err != nil { + return fmt.Errorf("unable to fetch channels: %w", err) + } + + // Call the call-back for each channel and its known policies. + for _, row := range rows { + node1, node2, err := buildNodeVertices( + row.Node1Pubkey, row.Node2Pubkey, + ) + if err != nil { + return fmt.Errorf("unable to build node vertices: %w", + err) + } + + edge, err := getAndBuildEdgeInfo( + ctx, db, chain, row.Channel.ID, row.Channel, node1, + node2, + ) + if err != nil { + return fmt.Errorf("unable to build channel info: %w", + err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + + p1, p2, err := getAndBuildChanPolicies( + ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2, + ) + if err != nil { + return fmt.Errorf("unable to build channel "+ + "policies: %w", err) + } + + // Determine the outgoing and incoming policy for this + // channel and node combo. + p1ToNode := row.Channel.NodeID2 + p2ToNode := row.Channel.NodeID1 + outPolicy, inPolicy := p1, p2 + if (p1 != nil && p1ToNode == id) || + (p2 != nil && p2ToNode != id) { + + outPolicy, inPolicy = p2, p1 + } + + if err := cb(edge, outPolicy, inPolicy); err != nil { + return err + } + } + + return nil +} + // updateChanEdgePolicy upserts the channel policy info we have stored for // a channel we already know of. func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, @@ -1515,3 +1651,311 @@ func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries, return nil } + +// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the +// provided dbChanRow and also fetches any other required information +// to construct the edge info. +func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries, + chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1, + node2 route.Vertex) (*models.ChannelEdgeInfo, error) { + + fv, extras, err := getChanFeaturesAndExtras( + ctx, db, dbChanID, + ) + if err != nil { + return nil, err + } + + op, err := wire.NewOutPointFromString(dbChan.Outpoint) + if err != nil { + return nil, err + } + + var featureBuf bytes.Buffer + if err := fv.Encode(&featureBuf); err != nil { + return nil, fmt.Errorf("unable to encode features: %w", err) + } + + recs, err := lnwire.CustomRecords(extras).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra signed "+ + "fields: %w", err) + } + if recs == nil { + recs = make([]byte, 0) + } + + var btcKey1, btcKey2 route.Vertex + copy(btcKey1[:], dbChan.BitcoinKey1) + copy(btcKey2[:], dbChan.BitcoinKey2) + + channel := &models.ChannelEdgeInfo{ + ChainHash: chain, + ChannelID: byteOrder.Uint64(dbChan.Scid), + NodeKey1Bytes: node1, + NodeKey2Bytes: node2, + BitcoinKey1Bytes: btcKey1, + BitcoinKey2Bytes: btcKey2, + ChannelPoint: *op, + Capacity: btcutil.Amount(dbChan.Capacity.Int64), + Features: featureBuf.Bytes(), + ExtraOpaqueData: recs, + } + + if dbChan.Bitcoin1Signature != nil { + channel.AuthProof = &models.ChannelAuthProof{ + NodeSig1Bytes: dbChan.Node1Signature, + NodeSig2Bytes: dbChan.Node2Signature, + BitcoinSig1Bytes: dbChan.Bitcoin1Signature, + BitcoinSig2Bytes: dbChan.Bitcoin2Signature, + } + } + + return channel, nil +} + +// buildNodeVertices is a helper that converts raw node public keys +// into route.Vertex instances. +func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex, + route.Vertex, error) { + + node1Vertex, err := route.NewVertexFromBytes(node1Pub) + if err != nil { + return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+ + "create vertex from node1 pubkey: %w", err) + } + + node2Vertex, err := route.NewVertexFromBytes(node2Pub) + if err != nil { + return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+ + "create vertex from node2 pubkey: %w", err) + } + + return node1Vertex, node2Vertex, nil +} + +// getChanFeaturesAndExtras fetches the channel features and extra TLV types +// for a channel with the given ID. +func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries, + id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) { + + rows, err := db.GetChannelFeaturesAndExtras(ctx, id) + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch channel "+ + "features and extras: %w", err) + } + + var ( + fv = lnwire.EmptyFeatureVector() + extras = make(map[uint64][]byte) + ) + for _, row := range rows { + if row.IsFeature { + fv.Set(lnwire.FeatureBit(row.FeatureBit)) + + continue + } + + tlvType, ok := row.ExtraKey.(int64) + if !ok { + return nil, nil, fmt.Errorf("unexpected type for "+ + "TLV type: %T", row.ExtraKey) + } + + valueBytes, ok := row.Value.([]byte) + if !ok { + return nil, nil, fmt.Errorf("unexpected type for "+ + "Value: %T", row.Value) + } + + extras[uint64(tlvType)] = valueBytes + } + + return fv, extras, nil +} + +// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves +// all the extra info required to build the complete models.ChannelEdgePolicy +// types. It returns two policies, which may be nil if the provided +// sqlc.ChannelPolicy records are nil. +func getAndBuildChanPolicies(ctx context.Context, db SQLQueries, + dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1, + node2 route.Vertex) (*models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + if dbPol1 == nil && dbPol2 == nil { + return nil, nil, nil + } + + var ( + policy1ID int64 + policy2ID int64 + ) + if dbPol1 != nil { + policy1ID = dbPol1.ID + } + if dbPol2 != nil { + policy2ID = dbPol2.ID + } + rows, err := db.GetChannelPolicyExtraTypes( + ctx, sqlc.GetChannelPolicyExtraTypesParams{ + ID: policy1ID, + ID_2: policy2ID, + }, + ) + if err != nil { + return nil, nil, err + } + + var ( + dbPol1Extras = make(map[uint64][]byte) + dbPol2Extras = make(map[uint64][]byte) + ) + for _, row := range rows { + switch row.PolicyID { + case policy1ID: + dbPol1Extras[uint64(row.Type)] = row.Value + case policy2ID: + dbPol2Extras[uint64(row.Type)] = row.Value + default: + return nil, nil, fmt.Errorf("unexpected policy ID %d "+ + "in row: %v", row.PolicyID, row) + } + } + + var pol1, pol2 *models.ChannelEdgePolicy + if dbPol1 != nil { + pol1, err = buildChanPolicy( + *dbPol1, channelID, dbPol1Extras, node2, true, + ) + if err != nil { + return nil, nil, err + } + } + if dbPol2 != nil { + pol2, err = buildChanPolicy( + *dbPol2, channelID, dbPol2Extras, node1, false, + ) + if err != nil { + return nil, nil, err + } + } + + return pol1, pol2, nil +} + +// buildChanPolicy builds a models.ChannelEdgePolicy instance from the +// provided sqlc.ChannelPolicy and other required information. +func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64, + extras map[uint64][]byte, toNode route.Vertex, + isNode1 bool) (*models.ChannelEdgePolicy, error) { + + recs, err := lnwire.CustomRecords(extras).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra signed "+ + "fields: %w", err) + } + + var msgFlags lnwire.ChanUpdateMsgFlags + if dbPolicy.MaxHtlcMsat.Valid { + msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc + } + + var chanFlags lnwire.ChanUpdateChanFlags + if !isNode1 { + chanFlags |= lnwire.ChanUpdateDirection + } + if dbPolicy.Disabled.Bool { + chanFlags |= lnwire.ChanUpdateDisabled + } + + var inboundFee fn.Option[lnwire.Fee] + if dbPolicy.InboundFeeRateMilliMsat.Valid || + dbPolicy.InboundBaseFeeMsat.Valid { + + inboundFee = fn.Some(lnwire.Fee{ + BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64), + FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64), + }) + } + + return &models.ChannelEdgePolicy{ + SigBytes: dbPolicy.Signature, + ChannelID: channelID, + LastUpdate: time.Unix( + dbPolicy.LastUpdate.Int64, 0, + ), + MessageFlags: msgFlags, + ChannelFlags: chanFlags, + TimeLockDelta: uint16(dbPolicy.Timelock), + MinHTLC: lnwire.MilliSatoshi( + dbPolicy.MinHtlcMsat, + ), + MaxHTLC: lnwire.MilliSatoshi( + dbPolicy.MaxHtlcMsat.Int64, + ), + FeeBaseMSat: lnwire.MilliSatoshi( + dbPolicy.BaseFeeMsat, + ), + FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm), + ToNode: toNode, + InboundFee: inboundFee, + ExtraOpaqueData: recs, + }, nil +} + +// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give +// row which is expected to be a sqlc type that contains channel policy +// information. It returns two policies, which may be nil if the policy +// information is not present in the row. +// +//nolint:ll +func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy, + error) { + + var policy1, policy2 *sqlc.ChannelPolicy + switch r := row.(type) { + case sqlc.ListChannelsByNodeIDRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.ChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.ChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + default: + return nil, nil, fmt.Errorf("unexpected row type in "+ + "extractChannelPolicies: %T", r) + } +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 82385008d..884d5c340 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -257,6 +257,120 @@ func (q *Queries) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDPara return i, err } +const getChannelFeaturesAndExtras = `-- name: GetChannelFeaturesAndExtras :many +SELECT + cf.channel_id, + true AS is_feature, + cf.feature_bit AS feature_bit, + NULL AS extra_key, + NULL AS value +FROM channel_features cf +WHERE cf.channel_id = $1 + +UNION ALL + +SELECT + cet.channel_id, + false AS is_feature, + 0 AS feature_bit, + cet.type AS extra_key, + cet.value AS value +FROM channel_extra_types cet +WHERE cet.channel_id = $1 +` + +type GetChannelFeaturesAndExtrasRow struct { + ChannelID int64 + IsFeature bool + FeatureBit int32 + ExtraKey interface{} + Value interface{} +} + +func (q *Queries) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelFeaturesAndExtras, channelID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelFeaturesAndExtrasRow + for rows.Next() { + var i GetChannelFeaturesAndExtrasRow + if err := rows.Scan( + &i.ChannelID, + &i.IsFeature, + &i.FeatureBit, + &i.ExtraKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChannelPolicyExtraTypes = `-- name: GetChannelPolicyExtraTypes :many +SELECT + cp.id AS policy_id, + cp.channel_id, + cp.node_id, + cpet.type, + cpet.value +FROM channel_policies cp +JOIN channel_policy_extra_types cpet +ON cp.id = cpet.channel_policy_id +WHERE cp.id = $1 OR cp.id = $2 +` + +type GetChannelPolicyExtraTypesParams struct { + ID int64 + ID_2 int64 +} + +type GetChannelPolicyExtraTypesRow struct { + PolicyID int64 + ChannelID int64 + NodeID int64 + Type int64 + Value []byte +} + +func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelPolicyExtraTypes, arg.ID, arg.ID_2) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelPolicyExtraTypesRow + for rows.Next() { + var i GetChannelPolicyExtraTypesRow + if err := rows.Scan( + &i.PolicyID, + &i.ChannelID, + &i.NodeID, + &i.Type, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM node_extra_types @@ -614,6 +728,158 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa return err } +const listChannelsByNodeID = `-- name: ListChannelsByNodeID :many + +SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + -- TODO(elle): use sqlc.embed to embed policy structs + -- once this issue is resolved: + -- https://github.com/sqlc-dev/sqlc/issues/2997 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.signature AS policy2_signature + +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND (c.node_id_1 = $2 OR c.node_id_2 = $2) +` + +type ListChannelsByNodeIDParams struct { + Version int16 + NodeID1 int64 +} + +type ListChannelsByNodeIDRow struct { + Channel Channel + Node1Pubkey []byte + Node2Pubkey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2Signature []byte +} + +func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) { + rows, err := q.db.QueryContext(ctx, listChannelsByNodeID, arg.Version, arg.NodeID1) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChannelsByNodeIDRow + for rows.Next() { + var i ListChannelsByNodeIDRow + if err := rows.Scan( + &i.Channel.ID, + &i.Channel.Version, + &i.Channel.Scid, + &i.Channel.NodeID1, + &i.Channel.NodeID2, + &i.Channel.Outpoint, + &i.Channel.Capacity, + &i.Channel.BitcoinKey1, + &i.Channel.BitcoinKey2, + &i.Channel.Node1Signature, + &i.Channel.Node2Signature, + &i.Channel.Bitcoin1Signature, + &i.Channel.Bitcoin2Signature, + &i.Node1Pubkey, + &i.Node2Pubkey, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const upsertEdgePolicy = `-- name: UpsertEdgePolicy :one /* ───────────────────────────────────────────── channel_policies table queries diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index e7b225204..d4a6a9dd0 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -29,6 +29,8 @@ type Querier interface { GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error) + GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) + GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) // This method may return more than one invoice if filter using multiple fields @@ -61,6 +63,7 @@ type Querier interface { InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error + ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) NextInvoiceSettleIndex(ctx context.Context) (int64, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 85000ec8f..c6dbb6e6b 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -165,6 +165,27 @@ FROM channels c WHERE c.scid = $1 AND c.version = $2; +-- name: GetChannelFeaturesAndExtras :many +SELECT + cf.channel_id, + true AS is_feature, + cf.feature_bit AS feature_bit, + NULL AS extra_key, + NULL AS value +FROM channel_features cf +WHERE cf.channel_id = $1 + +UNION ALL + +SELECT + cet.channel_id, + false AS is_feature, + 0 AS feature_bit, + cet.type AS extra_key, + cet.value AS value +FROM channel_extra_types cet +WHERE cet.channel_id = $1; + -- name: HighestSCID :one SELECT scid FROM channels @@ -172,6 +193,55 @@ WHERE version = $1 ORDER BY scid DESC LIMIT 1; +-- name: ListChannelsByNodeID :many + +SELECT sqlc.embed(c), + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + -- TODO(elle): use sqlc.embed to embed policy structs + -- once this issue is resolved: + -- https://github.com/sqlc-dev/sqlc/issues/2997 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.signature AS policy2_signature + +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND (c.node_id_1 = $2 OR c.node_id_2 = $2); + /* ───────────────────────────────────────────── channel_features table queries ───────────────────────────────────────────── @@ -237,6 +307,18 @@ INSERT INTO channel_policy_extra_types ( ) VALUES ($1, $2, $3); +-- name: GetChannelPolicyExtraTypes :many +SELECT + cp.id AS policy_id, + cp.channel_id, + cp.node_id, + cpet.type, + cpet.value +FROM channel_policies cp +JOIN channel_policy_extra_types cpet +ON cp.id = cpet.channel_policy_id +WHERE cp.id = $1 OR cp.id = $2; + -- name: DeleteChannelPolicyExtraTypes :exec DELETE FROM channel_policy_extra_types WHERE channel_policy_id = $1;