graph/db+sqldb: implement DeleteChannelEdges

This lets us run TestGraphZombieIndex against the SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 17:10:12 +02:00
parent 00b6e0204c
commit 2a6e6683eb
6 changed files with 326 additions and 1 deletions

View File

@@ -86,6 +86,7 @@ circuit. The indices are only available for forwarding events saved after v0.20.
* [4](https://github.com/lightningnetwork/lnd/pull/9931) * [4](https://github.com/lightningnetwork/lnd/pull/9931)
* [5](https://github.com/lightningnetwork/lnd/pull/9935) * [5](https://github.com/lightningnetwork/lnd/pull/9935)
* [6](https://github.com/lightningnetwork/lnd/pull/9936) * [6](https://github.com/lightningnetwork/lnd/pull/9936)
* [7](https://github.com/lightningnetwork/lnd/pull/9937)
## RPC Updates ## RPC Updates

View File

@@ -3811,7 +3811,7 @@ func TestGraphZombieIndex(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// We'll start by creating our test graph along with a test edge. // We'll start by creating our test graph along with a test edge.
graph := MakeTestGraph(t) graph := MakeTestGraphNew(t)
node1 := createTestVertex(t) node1 := createTestVertex(t)
node2 := createTestVertex(t) node2 := createTestVertex(t)

View File

@@ -88,6 +88,7 @@ type SQLQueries interface {
*/ */
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error) HighestSCID(ctx context.Context, version int16) ([]byte, error)
@@ -95,6 +96,7 @@ type SQLQueries interface {
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error) ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
DeleteChannel(ctx context.Context, id int64) error
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
@@ -1552,6 +1554,123 @@ func (s *SQLStore) NumZombies() (uint64, error) {
return numZombies, nil return numZombies, nil
} }
// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// it to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
var (
ctx = context.TODO()
deleted []*models.ChannelEdgeInfo
)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
for _, chanID := range chanIDs {
chanIDB := channelIDToBytes(chanID)
row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
return ErrEdgeNotFound
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w",
err)
}
node1, node2, err := buildNodeVertices(
row.Node.PubKey, row.Node_2.PubKey,
)
if err != nil {
return err
}
info, err := getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.Channel.ID,
row.Channel, node1, node2,
)
if err != nil {
return err
}
err = db.DeleteChannel(ctx, row.Channel.ID)
if err != nil {
return fmt.Errorf("unable to delete "+
"channel: %w", err)
}
deleted = append(deleted, info)
if !markZombie {
continue
}
nodeKey1, nodeKey2 := info.NodeKey1Bytes,
info.NodeKey2Bytes
if strictZombiePruning {
var e1UpdateTime, e2UpdateTime *time.Time
if row.Policy1LastUpdate.Valid {
e1Time := time.Unix(
row.Policy1LastUpdate.Int64, 0,
)
e1UpdateTime = &e1Time
}
if row.Policy2LastUpdate.Valid {
e2Time := time.Unix(
row.Policy2LastUpdate.Int64, 0,
)
e2UpdateTime = &e2Time
}
nodeKey1, nodeKey2 = makeZombiePubkeys(
info, e1UpdateTime, e2UpdateTime,
)
}
err = db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
NodeKey1: nodeKey1[:],
NodeKey2: nodeKey2[:],
},
)
if err != nil {
return fmt.Errorf("unable to mark channel as "+
"zombie: %w", err)
}
}
return nil
}, func() {
deleted = nil
})
if err != nil {
return nil, fmt.Errorf("unable to delete channel edges: %w",
err)
}
for _, chanID := range chanIDs {
s.rejectCache.remove(chanID)
s.chanCache.remove(chanID)
}
return deleted, nil
}
// forEachNodeDirectedChannel iterates through all channels of a given // forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the // node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is // channel and its incoming policy. If the node is not found, no error is

View File

@@ -114,6 +114,15 @@ func (q *Queries) CreateChannelExtraType(ctx context.Context, arg CreateChannelE
return err return err
} }
const deleteChannel = `-- name: DeleteChannel :exec
DELETE FROM channels WHERE id = $1
`
func (q *Queries) DeleteChannel(ctx context.Context, id int64) error {
_, err := q.db.ExecContext(ctx, deleteChannel, id)
return err
}
const deleteChannelPolicyExtraTypes = `-- name: DeleteChannelPolicyExtraTypes :exec const deleteChannelPolicyExtraTypes = `-- name: DeleteChannelPolicyExtraTypes :exec
DELETE FROM channel_policy_extra_types DELETE FROM channel_policy_extra_types
WHERE channel_policy_id = $1 WHERE channel_policy_id = $1
@@ -285,6 +294,151 @@ func (q *Queries) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDPara
return i, err return i, err
} }
const getChannelBySCIDWithPolicies = `-- name: GetChannelBySCIDWithPolicies :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature,
n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature,
-- Policy 1
cp1.id AS policy1_id,
cp1.node_id AS policy1_node_id,
cp1.version AS policy1_version,
cp1.timelock AS policy1_timelock,
cp1.fee_ppm AS policy1_fee_ppm,
cp1.base_fee_msat AS policy1_base_fee_msat,
cp1.min_htlc_msat AS policy1_min_htlc_msat,
cp1.max_htlc_msat AS policy1_max_htlc_msat,
cp1.last_update AS policy1_last_update,
cp1.disabled AS policy1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.signature AS policy1_signature,
-- Policy 2
cp2.id AS policy2_id,
cp2.node_id AS policy2_node_id,
cp2.version AS policy2_version,
cp2.timelock AS policy2_timelock,
cp2.fee_ppm AS policy2_fee_ppm,
cp2.base_fee_msat AS policy2_base_fee_msat,
cp2.min_htlc_msat AS policy2_min_htlc_msat,
cp2.max_htlc_msat AS policy2_max_htlc_msat,
cp2.last_update AS policy2_last_update,
cp2.disabled AS policy2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.signature AS policy2_signature
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE c.scid = $1
AND c.version = $2
`
type GetChannelBySCIDWithPoliciesParams struct {
Scid []byte
Version int16
}
type GetChannelBySCIDWithPoliciesRow struct {
Channel Channel
Node Node
Node_2 Node
Policy1ID sql.NullInt64
Policy1NodeID sql.NullInt64
Policy1Version sql.NullInt16
Policy1Timelock sql.NullInt32
Policy1FeePpm sql.NullInt64
Policy1BaseFeeMsat sql.NullInt64
Policy1MinHtlcMsat sql.NullInt64
Policy1MaxHtlcMsat sql.NullInt64
Policy1LastUpdate sql.NullInt64
Policy1Disabled sql.NullBool
Policy1InboundBaseFeeMsat sql.NullInt64
Policy1InboundFeeRateMilliMsat sql.NullInt64
Policy1Signature []byte
Policy2ID sql.NullInt64
Policy2NodeID sql.NullInt64
Policy2Version sql.NullInt16
Policy2Timelock sql.NullInt32
Policy2FeePpm sql.NullInt64
Policy2BaseFeeMsat sql.NullInt64
Policy2MinHtlcMsat sql.NullInt64
Policy2MaxHtlcMsat sql.NullInt64
Policy2LastUpdate sql.NullInt64
Policy2Disabled sql.NullBool
Policy2InboundBaseFeeMsat sql.NullInt64
Policy2InboundFeeRateMilliMsat sql.NullInt64
Policy2Signature []byte
}
func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) {
row := q.db.QueryRowContext(ctx, getChannelBySCIDWithPolicies, arg.Scid, arg.Version)
var i GetChannelBySCIDWithPoliciesRow
err := row.Scan(
&i.Channel.ID,
&i.Channel.Version,
&i.Channel.Scid,
&i.Channel.NodeID1,
&i.Channel.NodeID2,
&i.Channel.Outpoint,
&i.Channel.Capacity,
&i.Channel.BitcoinKey1,
&i.Channel.BitcoinKey2,
&i.Channel.Node1Signature,
&i.Channel.Node2Signature,
&i.Channel.Bitcoin1Signature,
&i.Channel.Bitcoin2Signature,
&i.Node.ID,
&i.Node.Version,
&i.Node.PubKey,
&i.Node.Alias,
&i.Node.LastUpdate,
&i.Node.Color,
&i.Node.Signature,
&i.Node_2.ID,
&i.Node_2.Version,
&i.Node_2.PubKey,
&i.Node_2.Alias,
&i.Node_2.LastUpdate,
&i.Node_2.Color,
&i.Node_2.Signature,
&i.Policy1ID,
&i.Policy1NodeID,
&i.Policy1Version,
&i.Policy1Timelock,
&i.Policy1FeePpm,
&i.Policy1BaseFeeMsat,
&i.Policy1MinHtlcMsat,
&i.Policy1MaxHtlcMsat,
&i.Policy1LastUpdate,
&i.Policy1Disabled,
&i.Policy1InboundBaseFeeMsat,
&i.Policy1InboundFeeRateMilliMsat,
&i.Policy1Signature,
&i.Policy2ID,
&i.Policy2NodeID,
&i.Policy2Version,
&i.Policy2Timelock,
&i.Policy2FeePpm,
&i.Policy2BaseFeeMsat,
&i.Policy2MinHtlcMsat,
&i.Policy2MaxHtlcMsat,
&i.Policy2LastUpdate,
&i.Policy2Disabled,
&i.Policy2InboundBaseFeeMsat,
&i.Policy2InboundFeeRateMilliMsat,
&i.Policy2Signature,
)
return i, err
}
const getChannelFeaturesAndExtras = `-- name: GetChannelFeaturesAndExtras :many const getChannelFeaturesAndExtras = `-- name: GetChannelFeaturesAndExtras :many
SELECT SELECT
cf.channel_id, cf.channel_id,

View File

@@ -17,6 +17,7 @@ type Querier interface {
CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error)
CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
DeleteChannel(ctx context.Context, id int64) error
DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error
DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error)
@@ -31,6 +32,7 @@ type Querier interface {
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error)
GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error) GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error)
GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error)

View File

@@ -373,6 +373,9 @@ WHERE c.version = $1 AND c.id > $2
ORDER BY c.id ORDER BY c.id
LIMIT $3; LIMIT $3;
-- name: DeleteChannel :exec
DELETE FROM channels WHERE id = $1;
/* ───────────────────────────────────────────── /* ─────────────────────────────────────────────
channel_features table queries channel_features table queries
───────────────────────────────────────────── ─────────────────────────────────────────────
@@ -434,6 +437,52 @@ WHERE channel_id = $1
AND node_id = $2 AND node_id = $2
AND version = $3; AND version = $3;
-- name: GetChannelBySCIDWithPolicies :one
SELECT
sqlc.embed(c),
sqlc.embed(n1),
sqlc.embed(n2),
-- Policy 1
cp1.id AS policy1_id,
cp1.node_id AS policy1_node_id,
cp1.version AS policy1_version,
cp1.timelock AS policy1_timelock,
cp1.fee_ppm AS policy1_fee_ppm,
cp1.base_fee_msat AS policy1_base_fee_msat,
cp1.min_htlc_msat AS policy1_min_htlc_msat,
cp1.max_htlc_msat AS policy1_max_htlc_msat,
cp1.last_update AS policy1_last_update,
cp1.disabled AS policy1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.signature AS policy1_signature,
-- Policy 2
cp2.id AS policy2_id,
cp2.node_id AS policy2_node_id,
cp2.version AS policy2_version,
cp2.timelock AS policy2_timelock,
cp2.fee_ppm AS policy2_fee_ppm,
cp2.base_fee_msat AS policy2_base_fee_msat,
cp2.min_htlc_msat AS policy2_min_htlc_msat,
cp2.max_htlc_msat AS policy2_max_htlc_msat,
cp2.last_update AS policy2_last_update,
cp2.disabled AS policy2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.signature AS policy2_signature
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE c.scid = @scid
AND c.version = @version;
/* ───────────────────────────────────────────── /* ─────────────────────────────────────────────
channel_policy_extra_types table queries channel_policy_extra_types table queries
───────────────────────────────────────────── ─────────────────────────────────────────────