graph/db+sqldb: pass set of outpoints to SQL

This commit adds a new GetChannelsByOutpoints query which takes a slice
of outpoint strings. This lets us then update PruneGraph to use
paginated calls to GetChannelsByOutpoints instead of making one DB call
per outpoint.
This commit is contained in:
Elle Mouton
2025-07-16 08:33:29 +02:00
parent 2fa30e8735
commit f72c48b283
4 changed files with 44 additions and 69 deletions

View File

@@ -93,7 +93,7 @@ type SQLQueries interface {
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
@@ -2365,22 +2365,9 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
prunedNodes []route.Vertex
)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
for _, outpoint := range spentOutputs {
// TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and
// Postgres.
//
// NOTE: this fetches channels for all protocol
// versions.
row, err := db.GetChannelByOutpoint(
ctx, outpoint.String(),
)
if errors.Is(err, sql.ErrNoRows) {
continue
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w",
err)
}
// Define the callback function for processing each channel.
channelCallback := func(ctx context.Context,
row sqlc.GetChannelsByOutpointsRow) error {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
@@ -2404,9 +2391,19 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
}
closedChans = append(closedChans, info)
return nil
}
err := db.UpsertPruneLogEntry(
err := s.forEachChanInOutpoints(
ctx, db, spentOutputs, channelCallback,
)
if err != nil {
return fmt.Errorf("unable to fetch channels by "+
"outpoints: %w", err)
}
err = db.UpsertPruneLogEntry(
ctx, sqlc.UpsertPruneLogEntryParams{
BlockHash: blockHash[:],
BlockHeight: int64(blockHeight),
@@ -2442,6 +2439,35 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
return closedChans, prunedNodes, nil
}
// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
outpoints []*wire.OutPoint, cb func(ctx context.Context,
row sqlc.GetChannelsByOutpointsRow) error) error {
// Create a wrapper that uses the transaction's db instance to execute
// the query.
queryWrapper := func(ctx context.Context,
pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
error) {
return db.GetChannelsByOutpoints(ctx, pageOutpoints)
}
// Define the conversion function from Outpoint to string.
outpointToString := func(outpoint *wire.OutPoint) string {
return outpoint.String()
}
return sqldb.ExecutePagedQuery(
ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
queryWrapper, cb,
)
}
// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel

View File

@@ -358,46 +358,6 @@ func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAn
return i, err
}
const getChannelByOutpoint = `-- name: GetChannelByOutpoint :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey
FROM graph_channels c
JOIN graph_nodes n1 ON c.node_id_1 = n1.id
JOIN graph_nodes n2 ON c.node_id_2 = n2.id
WHERE c.outpoint = $1
`
type GetChannelByOutpointRow struct {
GraphChannel GraphChannel
Node1Pubkey []byte
Node2Pubkey []byte
}
func (q *Queries) GetChannelByOutpoint(ctx context.Context, outpoint string) (GetChannelByOutpointRow, error) {
row := q.db.QueryRowContext(ctx, getChannelByOutpoint, outpoint)
var i GetChannelByOutpointRow
err := row.Scan(
&i.GraphChannel.ID,
&i.GraphChannel.Version,
&i.GraphChannel.Scid,
&i.GraphChannel.NodeID1,
&i.GraphChannel.NodeID2,
&i.GraphChannel.Outpoint,
&i.GraphChannel.Capacity,
&i.GraphChannel.BitcoinKey1,
&i.GraphChannel.BitcoinKey2,
&i.GraphChannel.Node1Signature,
&i.GraphChannel.Node2Signature,
&i.GraphChannel.Bitcoin1Signature,
&i.GraphChannel.Bitcoin2Signature,
&i.Node1Pubkey,
&i.Node2Pubkey,
)
return i, err
}
const getChannelByOutpointWithPolicies = `-- name: GetChannelByOutpointWithPolicies :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,

View File

@@ -35,7 +35,6 @@ type Querier interface {
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error)
GetChannelByOutpoint(ctx context.Context, outpoint string) (GetChannelByOutpointRow, error)
GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (GraphChannel, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error)

View File

@@ -242,16 +242,6 @@ FROM graph_channels c
WHERE c.outpoint IN
(sqlc.slice('outpoints')/*SLICE:outpoints*/);
-- name: GetChannelByOutpoint :one
SELECT
sqlc.embed(c),
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey
FROM graph_channels c
JOIN graph_nodes n1 ON c.node_id_1 = n1.id
JOIN graph_nodes n2 ON c.node_id_2 = n2.id
WHERE c.outpoint = $1;
-- name: GetChannelAndNodesBySCID :one
SELECT
c.*,