graph/db+sqldb: implement various zombie index methods

Here we implement the SQLStore methods:
- MarkEdgeZombie
- MarkEdgeLive
- IsZombieEdge
- NumZombies

These will be tested in the next commit as one more method
implementation is required.
This commit is contained in:
Elle Mouton
2025-06-11 17:04:21 +02:00
parent 137fc09230
commit 00b6e0204c
4 changed files with 282 additions and 0 deletions

View File

@ -108,6 +108,14 @@ type SQLQueries interface {
InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error) GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
/*
Zombie index queries.
*/
UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
CountZombieChannels(ctx context.Context, version int16) (int64, error)
DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
} }
// BatchedSQLQueries is a version of SQLQueries that's capable of batched // BatchedSQLQueries is a version of SQLQueries that's capable of batched
@ -1390,6 +1398,160 @@ func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
}), nil }), nil
} }
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
pubKey1, pubKey2 [33]byte) error {
ctx := context.TODO()
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
chanIDB := channelIDToBytes(chanID)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
return db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
NodeKey1: pubKey1[:],
NodeKey2: pubKey2[:],
},
)
}, sqldb.NoOpReset)
if err != nil {
return fmt.Errorf("unable to upsert zombie channel "+
"(channel_id=%d): %w", chanID, err)
}
s.rejectCache.remove(chanID)
s.chanCache.remove(chanID)
return nil
}
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
var (
ctx = context.TODO()
chanIDB = channelIDToBytes(chanID)
)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.DeleteZombieChannel(
ctx, sqlc.DeleteZombieChannelParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if err != nil {
return fmt.Errorf("unable to delete zombie channel: %w",
err)
}
rows, err := res.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return ErrZombieEdgeNotFound
} else if rows > 1 {
return fmt.Errorf("deleted %d zombie rows, "+
"expected 1", rows)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return fmt.Errorf("unable to mark edge live "+
"(channel_id=%d): %w", chanID, err)
}
s.rejectCache.remove(chanID)
s.chanCache.remove(chanID)
return err
}
// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
var (
ctx = context.TODO()
isZombie bool
pubKey1, pubKey2 route.Vertex
chanIDB = channelIDToBytes(chanID)
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
zombie, err := db.GetZombieChannel(
ctx, sqlc.GetZombieChannelParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
return nil
}
if err != nil {
return fmt.Errorf("unable to fetch zombie channel: %w",
err)
}
copy(pubKey1[:], zombie.NodeKey1)
copy(pubKey2[:], zombie.NodeKey2)
isZombie = true
return nil
}, sqldb.NoOpReset)
if err != nil {
// TODO(elle): update the IsZombieEdge method to return an
// error.
return false, route.Vertex{}, route.Vertex{}
}
return isZombie, pubKey1, pubKey2
}
// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
var (
ctx = context.TODO()
numZombies uint64
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
if err != nil {
return fmt.Errorf("unable to count zombie channels: %w",
err)
}
numZombies = uint64(count)
return nil
}, sqldb.NoOpReset)
if err != nil {
return 0, fmt.Errorf("unable to count zombies: %w", err)
}
return numZombies, nil
}
// forEachNodeDirectedChannel iterates through all channels of a given // forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the // node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is // channel and its incoming policy. If the node is not found, no error is

View File

@ -26,6 +26,19 @@ func (q *Queries) AddSourceNode(ctx context.Context, nodeID int64) error {
return err return err
} }
const countZombieChannels = `-- name: CountZombieChannels :one
SELECT COUNT(*)
FROM zombie_channels
WHERE version = $1
`
func (q *Queries) CountZombieChannels(ctx context.Context, version int16) (int64, error) {
row := q.db.QueryRowContext(ctx, countZombieChannels, version)
var count int64
err := row.Scan(&count)
return count, err
}
const createChannel = `-- name: CreateChannel :one const createChannel = `-- name: CreateChannel :one
/* ───────────────────────────────────────────── /* ─────────────────────────────────────────────
channels table queries channels table queries
@ -168,6 +181,21 @@ func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeaturePa
return err return err
} }
const deleteZombieChannel = `-- name: DeleteZombieChannel :execresult
DELETE FROM zombie_channels
WHERE scid = $1
AND version = $2
`
type DeleteZombieChannelParams struct {
Scid []byte
Version int16
}
func (q *Queries) DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) {
return q.db.ExecContext(ctx, deleteZombieChannel, arg.Scid, arg.Version)
}
const getChannelAndNodesBySCID = `-- name: GetChannelAndNodesBySCID :one const getChannelAndNodesBySCID = `-- name: GetChannelAndNodesBySCID :one
SELECT SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
@ -888,6 +916,30 @@ func (q *Queries) GetSourceNodesByVersion(ctx context.Context, version int16) ([
return items, nil return items, nil
} }
const getZombieChannel = `-- name: GetZombieChannel :one
SELECT scid, version, node_key_1, node_key_2
FROM zombie_channels
WHERE scid = $1
AND version = $2
`
type GetZombieChannelParams struct {
Scid []byte
Version int16
}
func (q *Queries) GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error) {
row := q.db.QueryRowContext(ctx, getZombieChannel, arg.Scid, arg.Version)
var i ZombieChannel
err := row.Scan(
&i.Scid,
&i.Version,
&i.NodeKey1,
&i.NodeKey2,
)
return i, err
}
const highestSCID = `-- name: HighestSCID :one const highestSCID = `-- name: HighestSCID :one
SELECT scid SELECT scid
FROM channels FROM channels
@ -1538,3 +1590,36 @@ func (q *Queries) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTy
_, err := q.db.ExecContext(ctx, upsertNodeExtraType, arg.NodeID, arg.Type, arg.Value) _, err := q.db.ExecContext(ctx, upsertNodeExtraType, arg.NodeID, arg.Type, arg.Value)
return err return err
} }
const upsertZombieChannel = `-- name: UpsertZombieChannel :exec
/* ─────────────────────────────────────────────
zombie_channels table queries
─────────────────────────────────────────────
*/
INSERT INTO zombie_channels (scid, version, node_key_1, node_key_2)
VALUES ($1, $2, $3, $4)
ON CONFLICT (scid, version)
DO UPDATE SET
-- If a conflict exists for the SCID and version pair, then we
-- update the node keys.
node_key_1 = COALESCE(EXCLUDED.node_key_1, zombie_channels.node_key_1),
node_key_2 = COALESCE(EXCLUDED.node_key_2, zombie_channels.node_key_2)
`
type UpsertZombieChannelParams struct {
Scid []byte
Version int16
NodeKey1 []byte
NodeKey2 []byte
}
func (q *Queries) UpsertZombieChannel(ctx context.Context, arg UpsertZombieChannelParams) error {
_, err := q.db.ExecContext(ctx, upsertZombieChannel,
arg.Scid,
arg.Version,
arg.NodeKey1,
arg.NodeKey2,
)
return err
}

View File

@ -13,6 +13,7 @@ import (
type Querier interface { type Querier interface {
AddSourceNode(ctx context.Context, nodeID int64) error AddSourceNode(ctx context.Context, nodeID int64) error
ClearKVInvoiceHashIndex(ctx context.Context) error ClearKVInvoiceHashIndex(ctx context.Context) error
CountZombieChannels(ctx context.Context, version int16) (int64, error)
CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error)
CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
@ -22,6 +23,7 @@ type Querier interface {
DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeAddresses(ctx context.Context, nodeID int64) error
DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error)
DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error
DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error)
FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error)
FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error)
FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error)
@ -54,6 +56,7 @@ type Querier interface {
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error)
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error) HighestSCID(ctx context.Context, version int16) ([]byte, error)
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
@ -90,6 +93,7 @@ type Querier interface {
UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error) UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error)
UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error)
UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error
UpsertZombieChannel(ctx context.Context, arg UpsertZombieChannelParams) error
} }
var _ Querier = (*Queries)(nil) var _ Querier = (*Queries)(nil)

View File

@ -460,3 +460,34 @@ WHERE cp.id = $1 OR cp.id = $2;
-- name: DeleteChannelPolicyExtraTypes :exec -- name: DeleteChannelPolicyExtraTypes :exec
DELETE FROM channel_policy_extra_types DELETE FROM channel_policy_extra_types
WHERE channel_policy_id = $1; WHERE channel_policy_id = $1;
/* ─────────────────────────────────────────────
zombie_channels table queries
─────────────────────────────────────────────
*/
-- name: UpsertZombieChannel :exec
INSERT INTO zombie_channels (scid, version, node_key_1, node_key_2)
VALUES ($1, $2, $3, $4)
ON CONFLICT (scid, version)
DO UPDATE SET
-- If a conflict exists for the SCID and version pair, then we
-- update the node keys.
node_key_1 = COALESCE(EXCLUDED.node_key_1, zombie_channels.node_key_1),
node_key_2 = COALESCE(EXCLUDED.node_key_2, zombie_channels.node_key_2);
-- name: DeleteZombieChannel :execresult
DELETE FROM zombie_channels
WHERE scid = $1
AND version = $2;
-- name: CountZombieChannels :one
SELECT COUNT(*)
FROM zombie_channels
WHERE version = $1;
-- name: GetZombieChannel :one
SELECT *
FROM zombie_channels
WHERE scid = $1
AND version = $2;