From 39e521e12b84a594afbfb09f4ddebcf8c06afec2 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 11 Jun 2025 16:23:36 +0200 Subject: [PATCH] graph/db+sqldb: implement ChanUpdatesInHorizon Add `ChanUpdatesInHorizon` method to the SQLStore. This lets us run `TestChanUpdatesInHorizon` against our SQL backends. --- graph/db/graph_test.go | 2 +- graph/db/sql_store.go | 192 +++++++++++++++++++++++++++++++++-- sqldb/sqlc/graph.sql.go | 172 +++++++++++++++++++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 56 ++++++++++ 5 files changed, 413 insertions(+), 10 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index dbd828527..cfbc6ed65 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1937,7 +1937,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { t.Parallel() ctx := context.Background() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // If we issue an arbitrary query before any channel updates are // inserted in the database, we should get zero results. diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 564ce96c6..4db4e7662 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -90,6 +90,7 @@ type SQLQueries interface { GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) + GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error @@ -924,6 +925,125 @@ func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex, }, sqldb.NoOpReset) } +// ChanUpdatesInHorizon returns all the known channel edges which have at least +// one edge that has an update timestamp within the specified horizon. +// +// NOTE: This is part of the V1Store interface. +func (s *SQLStore) ChanUpdatesInHorizon(startTime, + endTime time.Time) ([]ChannelEdge, error) { + + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + + var ( + ctx = context.TODO() + // To ensure we don't return duplicate ChannelEdges, we'll use + // an additional map to keep track of the edges already seen to + // prevent re-adding it. + edgesSeen = make(map[uint64]struct{}) + edgesToCache = make(map[uint64]ChannelEdge) + edges []ChannelEdge + hits int + ) + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + rows, err := db.GetChannelsByPolicyLastUpdateRange( + ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{ + Version: int16(ProtocolV1), + StartTime: sqldb.SQLInt64(startTime.Unix()), + EndTime: sqldb.SQLInt64(endTime.Unix()), + }, + ) + if err != nil { + return err + } + + for _, row := range rows { + // If we've already retrieved the info and policies for + // this edge, then we can skip it as we don't need to do + // so again. + chanIDInt := byteOrder.Uint64(row.Channel.Scid) + if _, ok := edgesSeen[chanIDInt]; ok { + continue + } + + if channel, ok := s.chanCache.get(chanIDInt); ok { + hits++ + edgesSeen[chanIDInt] = struct{}{} + edges = append(edges, channel) + + continue + } + + node1, node2, err := buildNodes( + ctx, db, row.Node, row.Node_2, + ) + if err != nil { + return err + } + + channel, err := getAndBuildEdgeInfo( + ctx, db, s.cfg.ChainHash, row.Channel.ID, + row.Channel, node1.PubKeyBytes, + node2.PubKeyBytes, + ) + if err != nil { + return fmt.Errorf("unable to build channel "+ + "info: %w", err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + + p1, p2, err := getAndBuildChanPolicies( + ctx, db, dbPol1, dbPol2, channel.ChannelID, + node1.PubKeyBytes, node2.PubKeyBytes, + ) + if err != nil { + return fmt.Errorf("unable to build channel "+ + "policies: %w", err) + } + + edgesSeen[chanIDInt] = struct{}{} + chanEdge := ChannelEdge{ + Info: channel, + Policy1: p1, + Policy2: p2, + Node1: node1, + Node2: node2, + } + edges = append(edges, chanEdge) + edgesToCache[chanIDInt] = chanEdge + } + + return nil + }, func() { + edgesSeen = make(map[uint64]struct{}) + edgesToCache = make(map[uint64]ChannelEdge) + edges = nil + }) + if err != nil { + return nil, fmt.Errorf("unable to fetch channels: %w", err) + } + + // Insert any edges loaded from disk into the cache. + for chanid, channel := range edgesToCache { + s.chanCache.insert(chanid, channel) + } + + if len(edges) > 0 { + log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)", + float64(hits)/float64(len(edges)), hits, len(edges)) + } else { + log.Debugf("ChanUpdatesInHorizon returned no edges in "+ + "horizon (%s, %s)", startTime, endTime) + } + + return edges, nil +} + // forEachNodeDirectedChannel iterates through all channels of a given // node, executing the passed callback on the directed edge representing the // channel and its incoming policy. If the node is not found, no error is @@ -977,12 +1097,7 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries, err) } - edge, err := buildCacheableChannelInfo( - row.Channel, node1, node2, - ) - if err != nil { - return err - } + edge := buildCacheableChannelInfo(row.Channel, node1, node2) dbPol1, dbPol2, err := extractChannelPolicies(row) if err != nil { @@ -1286,14 +1401,14 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries, // provided database channel row and the public keys of the two nodes // involved in the channel. func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub, - node2Pub route.Vertex) (*models.CachedEdgeInfo, error) { + node2Pub route.Vertex) *models.CachedEdgeInfo { return &models.CachedEdgeInfo{ ChannelID: byteOrder.Uint64(dbChan.Scid), NodeKey1Bytes: node1Pub, NodeKey2Bytes: node2Pub, Capacity: btcutil.Amount(dbChan.Capacity.Int64), - }, nil + } } // buildNode constructs a LightningNode instance from the given database node @@ -2302,17 +2417,76 @@ func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64, }, nil } +// buildNodes builds the models.LightningNode instances for the +// given row which is expected to be a sqlc type that contains node information. +func buildNodes(ctx context.Context, db SQLQueries, dbNode1, + dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode, + error) { + + node1, err := buildNode(ctx, db, &dbNode1) + if err != nil { + return nil, nil, err + } + + node2, err := buildNode(ctx, db, &dbNode2) + if err != nil { + return nil, nil, err + } + + return node1, node2, nil +} + // extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give // row which is expected to be a sqlc type that contains channel policy // information. It returns two policies, which may be nil if the policy // information is not present in the row. // -//nolint:ll +//nolint:ll,dupl func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy, error) { var policy1, policy2 *sqlc.ChannelPolicy switch r := row.(type) { + case sqlc.GetChannelsByPolicyLastUpdateRangeRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.ChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.ChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + case sqlc.ListChannelsByNodeIDRow: if r.Policy1ID.Valid { policy1 = &sqlc.ChannelPolicy{ diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 606d40217..9c0282dc1 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -371,6 +371,178 @@ func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannel return items, nil } +const getChannelsByPolicyLastUpdateRange = `-- name: GetChannelsByPolicyLastUpdateRange :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, + n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature, + + -- Policy 1 (node_id_1) + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.signature AS policy1_signature, + + -- Policy 2 (node_id_2) + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.signature AS policy2_signature + +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = $1 + AND ( + (cp1.last_update >= $2 AND cp1.last_update < $3) + OR + (cp2.last_update >= $2 AND cp2.last_update < $3) + ) +ORDER BY + CASE + WHEN COALESCE(cp1.last_update, 0) >= COALESCE(cp2.last_update, 0) + THEN COALESCE(cp1.last_update, 0) + ELSE COALESCE(cp2.last_update, 0) + END ASC +` + +type GetChannelsByPolicyLastUpdateRangeParams struct { + Version int16 + StartTime sql.NullInt64 + EndTime sql.NullInt64 +} + +type GetChannelsByPolicyLastUpdateRangeRow struct { + Channel Channel + Node Node + Node_2 Node + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2Signature []byte +} + +func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelsByPolicyLastUpdateRange, arg.Version, arg.StartTime, arg.EndTime) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsByPolicyLastUpdateRangeRow + for rows.Next() { + var i GetChannelsByPolicyLastUpdateRangeRow + if err := rows.Scan( + &i.Channel.ID, + &i.Channel.Version, + &i.Channel.Scid, + &i.Channel.NodeID1, + &i.Channel.NodeID2, + &i.Channel.Outpoint, + &i.Channel.Capacity, + &i.Channel.BitcoinKey1, + &i.Channel.BitcoinKey2, + &i.Channel.Node1Signature, + &i.Channel.Node2Signature, + &i.Channel.Bitcoin1Signature, + &i.Channel.Bitcoin2Signature, + &i.Node.ID, + &i.Node.Version, + &i.Node.PubKey, + &i.Node.Alias, + &i.Node.LastUpdate, + &i.Node.Color, + &i.Node.Signature, + &i.Node_2.ID, + &i.Node_2.Version, + &i.Node_2.PubKey, + &i.Node_2.Alias, + &i.Node_2.LastUpdate, + &i.Node_2.Color, + &i.Node_2.Signature, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM node_extra_types diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index c0abde3c2..1c057f60e 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -31,6 +31,7 @@ type Querier interface { GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) + GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) // This method may return more than one invoice if filter using multiple fields diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 510195906..eb77e7e3b 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -206,6 +206,62 @@ SELECT FROM channel_extra_types cet WHERE cet.channel_id = $1; +-- name: GetChannelsByPolicyLastUpdateRange :many +SELECT + sqlc.embed(c), + sqlc.embed(n1), + sqlc.embed(n2), + + -- Policy 1 (node_id_1) + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.signature AS policy1_signature, + + -- Policy 2 (node_id_2) + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.signature AS policy2_signature + +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = @version + AND ( + (cp1.last_update >= @start_time AND cp1.last_update < @end_time) + OR + (cp2.last_update >= @start_time AND cp2.last_update < @end_time) + ) +ORDER BY + CASE + WHEN COALESCE(cp1.last_update, 0) >= COALESCE(cp2.last_update, 0) + THEN COALESCE(cp1.last_update, 0) + ELSE COALESCE(cp2.last_update, 0) + END ASC; + -- name: HighestSCID :one SELECT scid FROM channels