graph/db+sqldb: impl PruneGraph, PruneTip, ChannelView

Which lets us run `TestGraphPruning` and `TestBatchedAddChannelEdge`
against our SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 17:46:37 +02:00
parent 2da701cc4f
commit 9dd0361ed0
5 changed files with 391 additions and 2 deletions

View File

@@ -1685,7 +1685,7 @@ func TestGraphPruning(t *testing.T) {
t.Parallel()
ctx := context.Background()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
sourceNode := createTestVertex(t)
if err := graph.SetSourceNode(ctx, sourceNode); err != nil {
@@ -4006,7 +4006,7 @@ func TestBatchedAddChannelEdge(t *testing.T) {
t.Parallel()
ctx := context.Background()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
sourceNode := createTestVertex(t)
require.Nil(t, graph.SetSourceNode(ctx, sourceNode))

View File

@@ -91,12 +91,14 @@ type SQLQueries interface {
*/
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
@@ -125,6 +127,12 @@ type SQLQueries interface {
CountZombieChannels(ctx context.Context, version int16) (int64, error)
DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
/*
Prune log table queries.
*/
GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
}
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
@@ -2230,6 +2238,213 @@ func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
return prunedNodes, nil
}
// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within he graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
blockHash *chainhash.Hash, blockHeight uint32) (
[]*models.ChannelEdgeInfo, []route.Vertex, error) {
ctx := context.TODO()
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
var (
closedChans []*models.ChannelEdgeInfo
prunedNodes []route.Vertex
)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
for _, outpoint := range spentOutputs {
// TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and
// Postgres.
//
// NOTE: this fetches channels for all protocol
// versions.
row, err := db.GetChannelByOutpoint(
ctx, outpoint.String(),
)
if errors.Is(err, sql.ErrNoRows) {
continue
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w",
err)
}
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return err
}
info, err := getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.Channel.ID,
row.Channel, node1, node2,
)
if err != nil {
return err
}
err = db.DeleteChannel(ctx, row.Channel.ID)
if err != nil {
return fmt.Errorf("unable to delete "+
"channel: %w", err)
}
closedChans = append(closedChans, info)
}
err := db.UpsertPruneLogEntry(
ctx, sqlc.UpsertPruneLogEntryParams{
BlockHash: blockHash[:],
BlockHeight: int64(blockHeight),
},
)
if err != nil {
return fmt.Errorf("unable to insert prune log "+
"entry: %w", err)
}
// Now that we've pruned some channels, we'll also prune any
// nodes that no longer have any channels.
prunedNodes, err = s.pruneGraphNodes(ctx, db)
if err != nil {
return fmt.Errorf("unable to prune graph nodes: %w",
err)
}
return nil
}, func() {
prunedNodes = nil
closedChans = nil
})
if err != nil {
return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
}
for _, channel := range closedChans {
s.rejectCache.remove(channel.ChannelID)
s.chanCache.remove(channel.ChannelID)
}
return closedChans, prunedNodes, nil
}
// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
var (
ctx = context.TODO()
edgePoints []EdgePoint
)
handleChannel := func(db SQLQueries,
channel sqlc.ListChannelsPaginatedRow) error {
pkScript, err := genMultiSigP2WSH(
channel.BitcoinKey1, channel.BitcoinKey2,
)
if err != nil {
return err
}
op, err := wire.NewOutPointFromString(channel.Outpoint)
if err != nil {
return err
}
edgePoints = append(edgePoints, EdgePoint{
FundingPkScript: pkScript,
OutPoint: *op,
})
return nil
}
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
lastID := int64(-1)
for {
rows, err := db.ListChannelsPaginated(
ctx, sqlc.ListChannelsPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: pageSize,
},
)
if err != nil {
return err
}
if len(rows) == 0 {
break
}
for _, row := range rows {
err := handleChannel(db, row)
if err != nil {
return err
}
lastID = row.ID
}
}
return nil
}, func() {
edgePoints = nil
})
if err != nil {
return nil, fmt.Errorf("unable to fetch channel view: %w", err)
}
return edgePoints, nil
}
// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
var (
ctx = context.TODO()
tipHash chainhash.Hash
tipHeight uint32
)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
pruneTip, err := db.GetPruneTip(ctx)
if errors.Is(err, sql.ErrNoRows) {
return ErrGraphNeverPruned
} else if err != nil {
return fmt.Errorf("unable to fetch prune tip: %w", err)
}
tipHash = chainhash.Hash(pruneTip.BlockHash)
tipHeight = uint32(pruneTip.BlockHeight)
return nil
}, sqldb.NoOpReset)
if err != nil {
return nil, 0, err
}
return &tipHash, tipHeight, nil
}
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
@@ -3389,6 +3604,11 @@ func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
if dbChan.Version != int16(ProtocolV1) {
return nil, fmt.Errorf("unsupported channel version: %d",
dbChan.Version)
}
fv, extras, err := getChanFeaturesAndExtras(
ctx, db, dbChanID,
)

View File

@@ -273,6 +273,46 @@ func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAn
return i, err
}
const getChannelByOutpoint = `-- name: GetChannelByOutpoint :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
WHERE c.outpoint = $1
`
type GetChannelByOutpointRow struct {
Channel Channel
Node1Pubkey []byte
Node2Pubkey []byte
}
func (q *Queries) GetChannelByOutpoint(ctx context.Context, outpoint string) (GetChannelByOutpointRow, error) {
row := q.db.QueryRowContext(ctx, getChannelByOutpoint, outpoint)
var i GetChannelByOutpointRow
err := row.Scan(
&i.Channel.ID,
&i.Channel.Version,
&i.Channel.Scid,
&i.Channel.NodeID1,
&i.Channel.NodeID2,
&i.Channel.Outpoint,
&i.Channel.Capacity,
&i.Channel.BitcoinKey1,
&i.Channel.BitcoinKey2,
&i.Channel.Node1Signature,
&i.Channel.Node2Signature,
&i.Channel.Bitcoin1Signature,
&i.Channel.Bitcoin2Signature,
&i.Node1Pubkey,
&i.Node2Pubkey,
)
return i, err
}
const getChannelByOutpointWithPolicies = `-- name: GetChannelByOutpointWithPolicies :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
@@ -1127,6 +1167,20 @@ func (q *Queries) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByL
return items, nil
}
const getPruneTip = `-- name: GetPruneTip :one
SELECT block_height, block_hash
FROM prune_log
ORDER BY block_height DESC
LIMIT 1
`
func (q *Queries) GetPruneTip(ctx context.Context) (PruneLog, error) {
row := q.db.QueryRowContext(ctx, getPruneTip)
var i PruneLog
err := row.Scan(&i.BlockHeight, &i.BlockHash)
return i, err
}
const getPublicV1ChannelsBySCID = `-- name: GetPublicV1ChannelsBySCID :many
SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature
FROM channels
@@ -1650,6 +1704,55 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo
return items, nil
}
const listChannelsPaginated = `-- name: ListChannelsPaginated :many
SELECT id, bitcoin_key_1, bitcoin_key_2, outpoint
FROM channels c
WHERE c.version = $1 AND c.id > $2
ORDER BY c.id
LIMIT $3
`
type ListChannelsPaginatedParams struct {
Version int16
ID int64
Limit int32
}
type ListChannelsPaginatedRow struct {
ID int64
BitcoinKey1 []byte
BitcoinKey2 []byte
Outpoint string
}
func (q *Queries) ListChannelsPaginated(ctx context.Context, arg ListChannelsPaginatedParams) ([]ListChannelsPaginatedRow, error) {
rows, err := q.db.QueryContext(ctx, listChannelsPaginated, arg.Version, arg.ID, arg.Limit)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListChannelsPaginatedRow
for rows.Next() {
var i ListChannelsPaginatedRow
if err := rows.Scan(
&i.ID,
&i.BitcoinKey1,
&i.BitcoinKey2,
&i.Outpoint,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listChannelsWithPoliciesPaginated = `-- name: ListChannelsWithPoliciesPaginated :many
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
@@ -2033,6 +2136,31 @@ func (q *Queries) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTy
return err
}
const upsertPruneLogEntry = `-- name: UpsertPruneLogEntry :exec
/* ─────────────────────────────────────────────
prune_log table queries
─────────────────────────────────────────────
*/
INSERT INTO prune_log (
block_height, block_hash
) VALUES (
$1, $2
)
ON CONFLICT(block_height) DO UPDATE SET
block_hash = EXCLUDED.block_hash
`
type UpsertPruneLogEntryParams struct {
BlockHeight int64
BlockHash []byte
}
func (q *Queries) UpsertPruneLogEntry(ctx context.Context, arg UpsertPruneLogEntryParams) error {
_, err := q.db.ExecContext(ctx, upsertPruneLogEntry, arg.BlockHeight, arg.BlockHash)
return err
}
const upsertZombieChannel = `-- name: UpsertZombieChannel :exec
/* ─────────────────────────────────────────────
zombie_channels table queries

View File

@@ -32,6 +32,7 @@ type Querier interface {
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error)
GetChannelByOutpoint(ctx context.Context, outpoint string) (GetChannelByOutpointRow, error)
GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error)
@@ -58,6 +59,7 @@ type Querier interface {
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
GetPruneTip(ctx context.Context) (PruneLog, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error)
GetSCIDByOutpoint(ctx context.Context, arg GetSCIDByOutpointParams) ([]byte, error)
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
@@ -86,6 +88,7 @@ type Querier interface {
IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error)
ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
ListChannelsPaginated(ctx context.Context, arg ListChannelsPaginatedParams) ([]ListChannelsPaginatedRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error)
ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error)
ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error)
@@ -108,6 +111,7 @@ type Querier interface {
UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error)
UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error)
UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error
UpsertPruneLogEntry(ctx context.Context, arg UpsertPruneLogEntryParams) error
UpsertZombieChannel(ctx context.Context, arg UpsertZombieChannelParams) error
}

View File

@@ -212,6 +212,16 @@ RETURNING id;
SELECT * FROM channels
WHERE scid = $1 AND version = $2;
-- name: GetChannelByOutpoint :one
SELECT
sqlc.embed(c),
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
WHERE c.outpoint = $1;
-- name: GetChannelAndNodesBySCID :one
SELECT
c.*,
@@ -411,6 +421,13 @@ WHERE node_1_signature IS NOT NULL
AND scid >= @start_scid
AND scid < @end_scid;
-- name: ListChannelsPaginated :many
SELECT id, bitcoin_key_1, bitcoin_key_2, outpoint
FROM channels c
WHERE c.version = $1 AND c.id > $2
ORDER BY c.id
LIMIT $3;
-- name: ListChannelsWithPoliciesPaginated :many
SELECT
sqlc.embed(c),
@@ -648,3 +665,23 @@ SELECT EXISTS (
WHERE scid = $1
AND version = $2
) AS is_zombie;
/* ─────────────────────────────────────────────
prune_log table queries
─────────────────────────────────────────────
*/
-- name: UpsertPruneLogEntry :exec
INSERT INTO prune_log (
block_height, block_hash
) VALUES (
$1, $2
)
ON CONFLICT(block_height) DO UPDATE SET
block_hash = EXCLUDED.block_hash;
-- name: GetPruneTip :one
SELECT block_height, block_hash
FROM prune_log
ORDER BY block_height DESC
LIMIT 1;