diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index d6d7c54ad..cd960b166 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1615,11 +1615,12 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, } var ( - ctx = context.TODO() - deleted []*models.ChannelEdgeInfo + ctx = context.TODO() + edges []*models.ChannelEdgeInfo ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - chanIDsToDelete := make([]int64, 0, len(chanIDs)) + // First, collect all channel rows. + var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow chanCallBack := func(ctx context.Context, row sqlc.GetChannelsBySCIDWithPoliciesRow) error { @@ -1628,65 +1629,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, scid := byteOrder.Uint64(row.GraphChannel.Scid) delete(chanLookup, scid) - node1, node2, err := buildNodeVertices( - row.GraphNode.PubKey, row.GraphNode_2.PubKey, - ) - if err != nil { - return err - } - - info, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1, node2, - ) - if err != nil { - return err - } - - deleted = append(deleted, info) - chanIDsToDelete = append( - chanIDsToDelete, row.GraphChannel.ID, - ) - - if !markZombie { - return nil - } - - nodeKey1, nodeKey2 := info.NodeKey1Bytes, - info.NodeKey2Bytes - if strictZombiePruning { - var e1UpdateTime, e2UpdateTime *time.Time - if row.Policy1LastUpdate.Valid { - e1Time := time.Unix( - row.Policy1LastUpdate.Int64, 0, - ) - e1UpdateTime = &e1Time - } - if row.Policy2LastUpdate.Valid { - e2Time := time.Unix( - row.Policy2LastUpdate.Int64, 0, - ) - e2UpdateTime = &e2Time - } - - nodeKey1, nodeKey2 = makeZombiePubkeys( - info.NodeKey1Bytes, info.NodeKey2Bytes, - e1UpdateTime, e2UpdateTime, - ) - } - - err = db.UpsertZombieChannel( - ctx, sqlc.UpsertZombieChannelParams{ - Version: int16(ProtocolV1), - Scid: channelIDToBytes(scid), - NodeKey1: nodeKey1[:], - NodeKey2: nodeKey2[:], - }, - ) - if err != nil { - return fmt.Errorf("unable to mark channel as "+ - "zombie: %w", err) - } + channelRows = append(channelRows, row) return nil } @@ -1702,9 +1645,37 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, return ErrEdgeNotFound } + if len(channelRows) == 0 { + return nil + } + + // Batch build all channel edges. + var chanIDsToDelete []int64 + edges, chanIDsToDelete, err = batchBuildChannelInfo( + ctx, s.cfg, db, channelRows, + ) + if err != nil { + return err + } + + if markZombie { + for i, row := range channelRows { + scid := byteOrder.Uint64(row.GraphChannel.Scid) + + err := handleZombieMarking( + ctx, db, row, edges[i], + strictZombiePruning, scid, + ) + if err != nil { + return fmt.Errorf("unable to mark "+ + "channel as zombie: %w", err) + } + } + } + return s.deleteChannels(ctx, db, chanIDsToDelete) }, func() { - deleted = nil + edges = nil // Re-fill the lookup map. for _, chanID := range chanIDs { @@ -1721,7 +1692,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, s.chanCache.remove(chanID) } - return deleted, nil + return edges, nil } // FetchChannelEdgesByID attempts to lookup the two directed edges for the @@ -5406,3 +5377,90 @@ func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context, return edges, nil } + +// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo +// instances from the provided rows using batch loading for channel data. +func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context, + cfg *SQLStoreConfig, db SQLQueries, rows []T) ( + []*models.ChannelEdgeInfo, []int64, error) { + + if len(rows) == 0 { + return nil, nil, nil + } + + // Collect all the channel IDs needed for batch loading. + channelIDs := make([]int64, len(rows)) + for i, row := range rows { + channelIDs[i] = row.Channel().ID + } + + // Batch load the channel data. + channelBatchData, err := batchLoadChannelData( + ctx, cfg.QueryCfg, db, channelIDs, nil, + ) + if err != nil { + return nil, nil, fmt.Errorf("unable to batch load channel "+ + "data: %w", err) + } + + // Build all channel edges using batch data. + edges := make([]*models.ChannelEdgeInfo, 0, len(rows)) + for _, row := range rows { + node1, node2, err := buildNodeVertices( + row.Node1Pub(), row.Node2Pub(), + ) + if err != nil { + return nil, nil, err + } + + // Build channel info using batch data + info, err := buildEdgeInfoWithBatchData( + cfg.ChainHash, row.Channel(), node1, node2, + channelBatchData, + ) + if err != nil { + return nil, nil, err + } + + edges = append(edges, info) + } + + return edges, channelIDs, nil +} + +// handleZombieMarking is a helper function that handles the logic of +// marking a channel as a zombie in the database. It takes into account whether +// we are in strict zombie pruning mode, and adjusts the node public keys +// accordingly based on the last update timestamps of the channel policies. +func handleZombieMarking(ctx context.Context, db SQLQueries, + row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo, + strictZombiePruning bool, scid uint64) error { + + nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes + + if strictZombiePruning { + var e1UpdateTime, e2UpdateTime *time.Time + if row.Policy1LastUpdate.Valid { + e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0) + e1UpdateTime = &e1Time + } + if row.Policy2LastUpdate.Valid { + e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0) + e2UpdateTime = &e2Time + } + + nodeKey1, nodeKey2 = makeZombiePubkeys( + info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime, + e2UpdateTime, + ) + } + + return db.UpsertZombieChannel( + ctx, sqlc.UpsertZombieChannelParams{ + Version: int16(ProtocolV1), + Scid: channelIDToBytes(scid), + NodeKey1: nodeKey1[:], + NodeKey2: nodeKey2[:], + }, + ) +} diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go index 64440a262..823002817 100644 --- a/sqldb/sqlc/db_custom.go +++ b/sqldb/sqlc/db_custom.go @@ -71,3 +71,37 @@ func (r GetChannelsByPolicyLastUpdateRangeRow) Node1() GraphNode { func (r GetChannelsByPolicyLastUpdateRangeRow) Node2() GraphNode { return r.GraphNode_2 } + +// ChannelAndNodeIDs is an interface that provides access to a channel and its +// two node public keys. +type ChannelAndNodeIDs interface { + // Channel returns the GraphChannel associated with this interface. + Channel() GraphChannel + + // Node1Pub returns the public key of the first node as a byte slice. + Node1Pub() []byte + + // Node2Pub returns the public key of the second node as a byte slice. + Node2Pub() []byte +} + +// Channel returns the GraphChannel associated with this interface. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDWithPoliciesRow) Channel() GraphChannel { + return r.GraphChannel +} + +// Node1Pub returns the public key of the first node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDWithPoliciesRow) Node1Pub() []byte { + return r.GraphNode.PubKey +} + +// Node2Pub returns the public key of the second node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDWithPoliciesRow) Node2Pub() []byte { + return r.GraphNode_2.PubKey +}