diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index b25b9e6d8..d5edcab17 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -76,6 +76,7 @@ circuit. The indices are only available for forwarding events saved after v0.20. * [1](https://github.com/lightningnetwork/lnd/pull/9866) * [2](https://github.com/lightningnetwork/lnd/pull/9869) * [3](https://github.com/lightningnetwork/lnd/pull/9887) + * [4](https://github.com/lightningnetwork/lnd/pull/9931) ## RPC Updates diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index b279f5476..8dab86611 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1183,7 +1183,7 @@ func TestAddEdgeProof(t *testing.T) { func TestForEachSourceNodeChannel(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // Create a source node (A) and set it as such in the DB. nodeA := createTestVertex(t) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 9bdad6e36..137c1d2e2 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -14,8 +14,11 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -77,7 +80,9 @@ type SQLQueries interface { CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) + GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) + ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error @@ -88,6 +93,7 @@ type SQLQueries interface { UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error) InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error + GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error } @@ -651,6 +657,136 @@ func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy, } } +// ForEachSourceNodeChannel iterates through all channels of the source node, +// executing the passed callback on each. The call-back is provided with the +// channel's outpoint, whether we have a policy for the channel and the channel +// peer's node information. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error) error { + + var ctx = context.TODO() + + return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1) + if err != nil { + return fmt.Errorf("unable to fetch source node: %w", + err) + } + + return forEachNodeChannel( + ctx, db, s.cfg.ChainHash, nodeID, + func(info *models.ChannelEdgeInfo, + outPolicy *models.ChannelEdgePolicy, + _ *models.ChannelEdgePolicy) error { + + // Fetch the other node. + var ( + otherNodePub [33]byte + node1 = info.NodeKey1Bytes + node2 = info.NodeKey2Bytes + ) + switch { + case bytes.Equal(node1[:], nodePub[:]): + otherNodePub = node2 + case bytes.Equal(node2[:], nodePub[:]): + otherNodePub = node1 + default: + return fmt.Errorf("node not " + + "participating in this channel") + } + + _, otherNode, err := getNodeByPubKey( + ctx, db, otherNodePub, + ) + if err != nil { + return fmt.Errorf("unable to fetch "+ + "other node(%x): %w", + otherNodePub, err) + } + + return cb( + info.ChannelPoint, outPolicy != nil, + otherNode, + ) + }, + ) + }, sqldb.NoOpReset) +} + +// forEachNodeChannel iterates through all channels of a node, executing +// the passed callback on each. The call-back is provided with the channel's +// edge information, the outgoing policy and the incoming policy for the +// channel and node combo. +func forEachNodeChannel(ctx context.Context, db SQLQueries, + chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + // Get all the V1 channels for this node.Add commentMore actions + rows, err := db.ListChannelsByNodeID( + ctx, sqlc.ListChannelsByNodeIDParams{ + Version: int16(ProtocolV1), + NodeID1: id, + }, + ) + if err != nil { + return fmt.Errorf("unable to fetch channels: %w", err) + } + + // Call the call-back for each channel and its known policies. + for _, row := range rows { + node1, node2, err := buildNodeVertices( + row.Node1Pubkey, row.Node2Pubkey, + ) + if err != nil { + return fmt.Errorf("unable to build node vertices: %w", + err) + } + + edge, err := getAndBuildEdgeInfo( + ctx, db, chain, row.Channel.ID, row.Channel, node1, + node2, + ) + if err != nil { + return fmt.Errorf("unable to build channel info: %w", + err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + + p1, p2, err := getAndBuildChanPolicies( + ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2, + ) + if err != nil { + return fmt.Errorf("unable to build channel "+ + "policies: %w", err) + } + + // Determine the outgoing and incoming policy for this + // channel and node combo. + p1ToNode := row.Channel.NodeID2 + p2ToNode := row.Channel.NodeID1 + outPolicy, inPolicy := p1, p2 + if (p1 != nil && p1ToNode == id) || + (p2 != nil && p2ToNode != id) { + + outPolicy, inPolicy = p2, p1 + } + + if err := cb(edge, outPolicy, inPolicy); err != nil { + return err + } + } + + return nil +} + // updateChanEdgePolicy upserts the channel policy info we have stored for // a channel we already know of. func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, @@ -1515,3 +1651,311 @@ func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries, return nil } + +// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the +// provided dbChanRow and also fetches any other required information +// to construct the edge info. +func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries, + chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1, + node2 route.Vertex) (*models.ChannelEdgeInfo, error) { + + fv, extras, err := getChanFeaturesAndExtras( + ctx, db, dbChanID, + ) + if err != nil { + return nil, err + } + + op, err := wire.NewOutPointFromString(dbChan.Outpoint) + if err != nil { + return nil, err + } + + var featureBuf bytes.Buffer + if err := fv.Encode(&featureBuf); err != nil { + return nil, fmt.Errorf("unable to encode features: %w", err) + } + + recs, err := lnwire.CustomRecords(extras).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra signed "+ + "fields: %w", err) + } + if recs == nil { + recs = make([]byte, 0) + } + + var btcKey1, btcKey2 route.Vertex + copy(btcKey1[:], dbChan.BitcoinKey1) + copy(btcKey2[:], dbChan.BitcoinKey2) + + channel := &models.ChannelEdgeInfo{ + ChainHash: chain, + ChannelID: byteOrder.Uint64(dbChan.Scid), + NodeKey1Bytes: node1, + NodeKey2Bytes: node2, + BitcoinKey1Bytes: btcKey1, + BitcoinKey2Bytes: btcKey2, + ChannelPoint: *op, + Capacity: btcutil.Amount(dbChan.Capacity.Int64), + Features: featureBuf.Bytes(), + ExtraOpaqueData: recs, + } + + if dbChan.Bitcoin1Signature != nil { + channel.AuthProof = &models.ChannelAuthProof{ + NodeSig1Bytes: dbChan.Node1Signature, + NodeSig2Bytes: dbChan.Node2Signature, + BitcoinSig1Bytes: dbChan.Bitcoin1Signature, + BitcoinSig2Bytes: dbChan.Bitcoin2Signature, + } + } + + return channel, nil +} + +// buildNodeVertices is a helper that converts raw node public keys +// into route.Vertex instances. +func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex, + route.Vertex, error) { + + node1Vertex, err := route.NewVertexFromBytes(node1Pub) + if err != nil { + return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+ + "create vertex from node1 pubkey: %w", err) + } + + node2Vertex, err := route.NewVertexFromBytes(node2Pub) + if err != nil { + return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+ + "create vertex from node2 pubkey: %w", err) + } + + return node1Vertex, node2Vertex, nil +} + +// getChanFeaturesAndExtras fetches the channel features and extra TLV types +// for a channel with the given ID. +func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries, + id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) { + + rows, err := db.GetChannelFeaturesAndExtras(ctx, id) + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch channel "+ + "features and extras: %w", err) + } + + var ( + fv = lnwire.EmptyFeatureVector() + extras = make(map[uint64][]byte) + ) + for _, row := range rows { + if row.IsFeature { + fv.Set(lnwire.FeatureBit(row.FeatureBit)) + + continue + } + + tlvType, ok := row.ExtraKey.(int64) + if !ok { + return nil, nil, fmt.Errorf("unexpected type for "+ + "TLV type: %T", row.ExtraKey) + } + + valueBytes, ok := row.Value.([]byte) + if !ok { + return nil, nil, fmt.Errorf("unexpected type for "+ + "Value: %T", row.Value) + } + + extras[uint64(tlvType)] = valueBytes + } + + return fv, extras, nil +} + +// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves +// all the extra info required to build the complete models.ChannelEdgePolicy +// types. It returns two policies, which may be nil if the provided +// sqlc.ChannelPolicy records are nil. +func getAndBuildChanPolicies(ctx context.Context, db SQLQueries, + dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1, + node2 route.Vertex) (*models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + if dbPol1 == nil && dbPol2 == nil { + return nil, nil, nil + } + + var ( + policy1ID int64 + policy2ID int64 + ) + if dbPol1 != nil { + policy1ID = dbPol1.ID + } + if dbPol2 != nil { + policy2ID = dbPol2.ID + } + rows, err := db.GetChannelPolicyExtraTypes( + ctx, sqlc.GetChannelPolicyExtraTypesParams{ + ID: policy1ID, + ID_2: policy2ID, + }, + ) + if err != nil { + return nil, nil, err + } + + var ( + dbPol1Extras = make(map[uint64][]byte) + dbPol2Extras = make(map[uint64][]byte) + ) + for _, row := range rows { + switch row.PolicyID { + case policy1ID: + dbPol1Extras[uint64(row.Type)] = row.Value + case policy2ID: + dbPol2Extras[uint64(row.Type)] = row.Value + default: + return nil, nil, fmt.Errorf("unexpected policy ID %d "+ + "in row: %v", row.PolicyID, row) + } + } + + var pol1, pol2 *models.ChannelEdgePolicy + if dbPol1 != nil { + pol1, err = buildChanPolicy( + *dbPol1, channelID, dbPol1Extras, node2, true, + ) + if err != nil { + return nil, nil, err + } + } + if dbPol2 != nil { + pol2, err = buildChanPolicy( + *dbPol2, channelID, dbPol2Extras, node1, false, + ) + if err != nil { + return nil, nil, err + } + } + + return pol1, pol2, nil +} + +// buildChanPolicy builds a models.ChannelEdgePolicy instance from the +// provided sqlc.ChannelPolicy and other required information. +func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64, + extras map[uint64][]byte, toNode route.Vertex, + isNode1 bool) (*models.ChannelEdgePolicy, error) { + + recs, err := lnwire.CustomRecords(extras).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra signed "+ + "fields: %w", err) + } + + var msgFlags lnwire.ChanUpdateMsgFlags + if dbPolicy.MaxHtlcMsat.Valid { + msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc + } + + var chanFlags lnwire.ChanUpdateChanFlags + if !isNode1 { + chanFlags |= lnwire.ChanUpdateDirection + } + if dbPolicy.Disabled.Bool { + chanFlags |= lnwire.ChanUpdateDisabled + } + + var inboundFee fn.Option[lnwire.Fee] + if dbPolicy.InboundFeeRateMilliMsat.Valid || + dbPolicy.InboundBaseFeeMsat.Valid { + + inboundFee = fn.Some(lnwire.Fee{ + BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64), + FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64), + }) + } + + return &models.ChannelEdgePolicy{ + SigBytes: dbPolicy.Signature, + ChannelID: channelID, + LastUpdate: time.Unix( + dbPolicy.LastUpdate.Int64, 0, + ), + MessageFlags: msgFlags, + ChannelFlags: chanFlags, + TimeLockDelta: uint16(dbPolicy.Timelock), + MinHTLC: lnwire.MilliSatoshi( + dbPolicy.MinHtlcMsat, + ), + MaxHTLC: lnwire.MilliSatoshi( + dbPolicy.MaxHtlcMsat.Int64, + ), + FeeBaseMSat: lnwire.MilliSatoshi( + dbPolicy.BaseFeeMsat, + ), + FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm), + ToNode: toNode, + InboundFee: inboundFee, + ExtraOpaqueData: recs, + }, nil +} + +// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give +// row which is expected to be a sqlc type that contains channel policy +// information. It returns two policies, which may be nil if the policy +// information is not present in the row. +// +//nolint:ll +func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy, + error) { + + var policy1, policy2 *sqlc.ChannelPolicy + switch r := row.(type) { + case sqlc.ListChannelsByNodeIDRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.ChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.ChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + default: + return nil, nil, fmt.Errorf("unexpected row type in "+ + "extractChannelPolicies: %T", r) + } +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 82385008d..884d5c340 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -257,6 +257,120 @@ func (q *Queries) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDPara return i, err } +const getChannelFeaturesAndExtras = `-- name: GetChannelFeaturesAndExtras :many +SELECT + cf.channel_id, + true AS is_feature, + cf.feature_bit AS feature_bit, + NULL AS extra_key, + NULL AS value +FROM channel_features cf +WHERE cf.channel_id = $1 + +UNION ALL + +SELECT + cet.channel_id, + false AS is_feature, + 0 AS feature_bit, + cet.type AS extra_key, + cet.value AS value +FROM channel_extra_types cet +WHERE cet.channel_id = $1 +` + +type GetChannelFeaturesAndExtrasRow struct { + ChannelID int64 + IsFeature bool + FeatureBit int32 + ExtraKey interface{} + Value interface{} +} + +func (q *Queries) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelFeaturesAndExtras, channelID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelFeaturesAndExtrasRow + for rows.Next() { + var i GetChannelFeaturesAndExtrasRow + if err := rows.Scan( + &i.ChannelID, + &i.IsFeature, + &i.FeatureBit, + &i.ExtraKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChannelPolicyExtraTypes = `-- name: GetChannelPolicyExtraTypes :many +SELECT + cp.id AS policy_id, + cp.channel_id, + cp.node_id, + cpet.type, + cpet.value +FROM channel_policies cp +JOIN channel_policy_extra_types cpet +ON cp.id = cpet.channel_policy_id +WHERE cp.id = $1 OR cp.id = $2 +` + +type GetChannelPolicyExtraTypesParams struct { + ID int64 + ID_2 int64 +} + +type GetChannelPolicyExtraTypesRow struct { + PolicyID int64 + ChannelID int64 + NodeID int64 + Type int64 + Value []byte +} + +func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelPolicyExtraTypes, arg.ID, arg.ID_2) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelPolicyExtraTypesRow + for rows.Next() { + var i GetChannelPolicyExtraTypesRow + if err := rows.Scan( + &i.PolicyID, + &i.ChannelID, + &i.NodeID, + &i.Type, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM node_extra_types @@ -614,6 +728,158 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa return err } +const listChannelsByNodeID = `-- name: ListChannelsByNodeID :many + +SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + -- TODO(elle): use sqlc.embed to embed policy structs + -- once this issue is resolved: + -- https://github.com/sqlc-dev/sqlc/issues/2997 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.signature AS policy2_signature + +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND (c.node_id_1 = $2 OR c.node_id_2 = $2) +` + +type ListChannelsByNodeIDParams struct { + Version int16 + NodeID1 int64 +} + +type ListChannelsByNodeIDRow struct { + Channel Channel + Node1Pubkey []byte + Node2Pubkey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2Signature []byte +} + +func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) { + rows, err := q.db.QueryContext(ctx, listChannelsByNodeID, arg.Version, arg.NodeID1) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListChannelsByNodeIDRow + for rows.Next() { + var i ListChannelsByNodeIDRow + if err := rows.Scan( + &i.Channel.ID, + &i.Channel.Version, + &i.Channel.Scid, + &i.Channel.NodeID1, + &i.Channel.NodeID2, + &i.Channel.Outpoint, + &i.Channel.Capacity, + &i.Channel.BitcoinKey1, + &i.Channel.BitcoinKey2, + &i.Channel.Node1Signature, + &i.Channel.Node2Signature, + &i.Channel.Bitcoin1Signature, + &i.Channel.Bitcoin2Signature, + &i.Node1Pubkey, + &i.Node2Pubkey, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const upsertEdgePolicy = `-- name: UpsertEdgePolicy :one /* ───────────────────────────────────────────── channel_policies table queries diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index e7b225204..d4a6a9dd0 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -29,6 +29,8 @@ type Querier interface { GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error) + GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) + GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) // This method may return more than one invoice if filter using multiple fields @@ -61,6 +63,7 @@ type Querier interface { InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error + ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) NextInvoiceSettleIndex(ctx context.Context) (int64, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 85000ec8f..c6dbb6e6b 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -165,6 +165,27 @@ FROM channels c WHERE c.scid = $1 AND c.version = $2; +-- name: GetChannelFeaturesAndExtras :many +SELECT + cf.channel_id, + true AS is_feature, + cf.feature_bit AS feature_bit, + NULL AS extra_key, + NULL AS value +FROM channel_features cf +WHERE cf.channel_id = $1 + +UNION ALL + +SELECT + cet.channel_id, + false AS is_feature, + 0 AS feature_bit, + cet.type AS extra_key, + cet.value AS value +FROM channel_extra_types cet +WHERE cet.channel_id = $1; + -- name: HighestSCID :one SELECT scid FROM channels @@ -172,6 +193,55 @@ WHERE version = $1 ORDER BY scid DESC LIMIT 1; +-- name: ListChannelsByNodeID :many + +SELECT sqlc.embed(c), + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + -- TODO(elle): use sqlc.embed to embed policy structs + -- once this issue is resolved: + -- https://github.com/sqlc-dev/sqlc/issues/2997 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.signature AS policy2_signature + +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND (c.node_id_1 = $2 OR c.node_id_2 = $2); + /* ───────────────────────────────────────────── channel_features table queries ───────────────────────────────────────────── @@ -237,6 +307,18 @@ INSERT INTO channel_policy_extra_types ( ) VALUES ($1, $2, $3); +-- name: GetChannelPolicyExtraTypes :many +SELECT + cp.id AS policy_id, + cp.channel_id, + cp.node_id, + cpet.type, + cpet.value +FROM channel_policies cp +JOIN channel_policy_extra_types cpet +ON cp.id = cpet.channel_policy_id +WHERE cp.id = $1 OR cp.id = $2; + -- name: DeleteChannelPolicyExtraTypes :exec DELETE FROM channel_policy_extra_types WHERE channel_policy_id = $1;