From 00b6e0204c5fb4775e23489daf4be6ae5bf4e4c8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 11 Jun 2025 17:04:21 +0200 Subject: [PATCH] graph/db+sqldb: implement various zombie index methods Here we implement the SQLStore methods: - MarkEdgeZombie - MarkEdgeLive - IsZombieEdge - NumZombies These will be tested in the next commit as one more method implementation is required. --- graph/db/sql_store.go | 162 +++++++++++++++++++++++++++++++++++ sqldb/sqlc/graph.sql.go | 85 ++++++++++++++++++ sqldb/sqlc/querier.go | 4 + sqldb/sqlc/queries/graph.sql | 31 +++++++ 4 files changed, 282 insertions(+) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index f8f64f125..7aff00ad6 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -108,6 +108,14 @@ type SQLQueries interface { InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error + + /* + Zombie index queries. + */ + UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error + GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error) + CountZombieChannels(ctx context.Context, version int16) (int64, error) + DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error) } // BatchedSQLQueries is a version of SQLQueries that's capable of batched @@ -1390,6 +1398,160 @@ func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32, }), nil } +// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a +// zombie. This method is used on an ad-hoc basis, when channels need to be +// marked as zombies outside the normal pruning cycle. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) MarkEdgeZombie(chanID uint64, + pubKey1, pubKey2 [33]byte) error { + + ctx := context.TODO() + + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + + chanIDB := channelIDToBytes(chanID) + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + return db.UpsertZombieChannel( + ctx, sqlc.UpsertZombieChannelParams{ + Version: int16(ProtocolV1), + Scid: chanIDB[:], + NodeKey1: pubKey1[:], + NodeKey2: pubKey2[:], + }, + ) + }, sqldb.NoOpReset) + if err != nil { + return fmt.Errorf("unable to upsert zombie channel "+ + "(channel_id=%d): %w", chanID, err) + } + + s.rejectCache.remove(chanID) + s.chanCache.remove(chanID) + + return nil +} + +// MarkEdgeLive clears an edge from our zombie index, deeming it as live. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) MarkEdgeLive(chanID uint64) error { + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + + var ( + ctx = context.TODO() + chanIDB = channelIDToBytes(chanID) + ) + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + res, err := db.DeleteZombieChannel( + ctx, sqlc.DeleteZombieChannelParams{ + Scid: chanIDB[:], + Version: int16(ProtocolV1), + }, + ) + if err != nil { + return fmt.Errorf("unable to delete zombie channel: %w", + err) + } + + rows, err := res.RowsAffected() + if err != nil { + return err + } + + if rows == 0 { + return ErrZombieEdgeNotFound + } else if rows > 1 { + return fmt.Errorf("deleted %d zombie rows, "+ + "expected 1", rows) + } + + return nil + }, sqldb.NoOpReset) + if err != nil { + return fmt.Errorf("unable to mark edge live "+ + "(channel_id=%d): %w", chanID, err) + } + + s.rejectCache.remove(chanID) + s.chanCache.remove(chanID) + + return err +} + +// IsZombieEdge returns whether the edge is considered zombie. If it is a +// zombie, then the two node public keys corresponding to this edge are also +// returned. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { + var ( + ctx = context.TODO() + isZombie bool + pubKey1, pubKey2 route.Vertex + chanIDB = channelIDToBytes(chanID) + ) + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + zombie, err := db.GetZombieChannel( + ctx, sqlc.GetZombieChannelParams{ + Scid: chanIDB[:], + Version: int16(ProtocolV1), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil + } + if err != nil { + return fmt.Errorf("unable to fetch zombie channel: %w", + err) + } + + copy(pubKey1[:], zombie.NodeKey1) + copy(pubKey2[:], zombie.NodeKey2) + isZombie = true + + return nil + }, sqldb.NoOpReset) + if err != nil { + // TODO(elle): update the IsZombieEdge method to return an + // error. + return false, route.Vertex{}, route.Vertex{} + } + + return isZombie, pubKey1, pubKey2 +} + +// NumZombies returns the current number of zombie channels in the graph. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) NumZombies() (uint64, error) { + var ( + ctx = context.TODO() + numZombies uint64 + ) + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + count, err := db.CountZombieChannels(ctx, int16(ProtocolV1)) + if err != nil { + return fmt.Errorf("unable to count zombie channels: %w", + err) + } + + numZombies = uint64(count) + + return nil + }, sqldb.NoOpReset) + if err != nil { + return 0, fmt.Errorf("unable to count zombies: %w", err) + } + + return numZombies, nil +} + // forEachNodeDirectedChannel iterates through all channels of a given // node, executing the passed callback on the directed edge representing the // channel and its incoming policy. If the node is not found, no error is diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 76e61fe8c..f5f6197a0 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -26,6 +26,19 @@ func (q *Queries) AddSourceNode(ctx context.Context, nodeID int64) error { return err } +const countZombieChannels = `-- name: CountZombieChannels :one +SELECT COUNT(*) +FROM zombie_channels +WHERE version = $1 +` + +func (q *Queries) CountZombieChannels(ctx context.Context, version int16) (int64, error) { + row := q.db.QueryRowContext(ctx, countZombieChannels, version) + var count int64 + err := row.Scan(&count) + return count, err +} + const createChannel = `-- name: CreateChannel :one /* ───────────────────────────────────────────── channels table queries @@ -168,6 +181,21 @@ func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeaturePa return err } +const deleteZombieChannel = `-- name: DeleteZombieChannel :execresult +DELETE FROM zombie_channels +WHERE scid = $1 +AND version = $2 +` + +type DeleteZombieChannelParams struct { + Scid []byte + Version int16 +} + +func (q *Queries) DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) { + return q.db.ExecContext(ctx, deleteZombieChannel, arg.Scid, arg.Version) +} + const getChannelAndNodesBySCID = `-- name: GetChannelAndNodesBySCID :one SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, @@ -888,6 +916,30 @@ func (q *Queries) GetSourceNodesByVersion(ctx context.Context, version int16) ([ return items, nil } +const getZombieChannel = `-- name: GetZombieChannel :one +SELECT scid, version, node_key_1, node_key_2 +FROM zombie_channels +WHERE scid = $1 +AND version = $2 +` + +type GetZombieChannelParams struct { + Scid []byte + Version int16 +} + +func (q *Queries) GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error) { + row := q.db.QueryRowContext(ctx, getZombieChannel, arg.Scid, arg.Version) + var i ZombieChannel + err := row.Scan( + &i.Scid, + &i.Version, + &i.NodeKey1, + &i.NodeKey2, + ) + return i, err +} + const highestSCID = `-- name: HighestSCID :one SELECT scid FROM channels @@ -1538,3 +1590,36 @@ func (q *Queries) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTy _, err := q.db.ExecContext(ctx, upsertNodeExtraType, arg.NodeID, arg.Type, arg.Value) return err } + +const upsertZombieChannel = `-- name: UpsertZombieChannel :exec +/* ───────────────────────────────────────────── + zombie_channels table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO zombie_channels (scid, version, node_key_1, node_key_2) +VALUES ($1, $2, $3, $4) +ON CONFLICT (scid, version) +DO UPDATE SET + -- If a conflict exists for the SCID and version pair, then we + -- update the node keys. + node_key_1 = COALESCE(EXCLUDED.node_key_1, zombie_channels.node_key_1), + node_key_2 = COALESCE(EXCLUDED.node_key_2, zombie_channels.node_key_2) +` + +type UpsertZombieChannelParams struct { + Scid []byte + Version int16 + NodeKey1 []byte + NodeKey2 []byte +} + +func (q *Queries) UpsertZombieChannel(ctx context.Context, arg UpsertZombieChannelParams) error { + _, err := q.db.ExecContext(ctx, upsertZombieChannel, + arg.Scid, + arg.Version, + arg.NodeKey1, + arg.NodeKey2, + ) + return err +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 8174abcd9..1796d96f2 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -13,6 +13,7 @@ import ( type Querier interface { AddSourceNode(ctx context.Context, nodeID int64) error ClearKVInvoiceHashIndex(ctx context.Context) error + CountZombieChannels(ctx context.Context, version int16) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) @@ -22,6 +23,7 @@ type Querier interface { DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error + DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) @@ -54,6 +56,7 @@ type Querier interface { GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) + GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error @@ -90,6 +93,7 @@ type Querier interface { UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error) UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error + UpsertZombieChannel(ctx context.Context, arg UpsertZombieChannelParams) error } var _ Querier = (*Queries)(nil) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 5560ec90d..a8a5362d2 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -460,3 +460,34 @@ WHERE cp.id = $1 OR cp.id = $2; -- name: DeleteChannelPolicyExtraTypes :exec DELETE FROM channel_policy_extra_types WHERE channel_policy_id = $1; + +/* ───────────────────────────────────────────── + zombie_channels table queries + ───────────────────────────────────────────── +*/ + +-- name: UpsertZombieChannel :exec +INSERT INTO zombie_channels (scid, version, node_key_1, node_key_2) +VALUES ($1, $2, $3, $4) +ON CONFLICT (scid, version) +DO UPDATE SET + -- If a conflict exists for the SCID and version pair, then we + -- update the node keys. + node_key_1 = COALESCE(EXCLUDED.node_key_1, zombie_channels.node_key_1), + node_key_2 = COALESCE(EXCLUDED.node_key_2, zombie_channels.node_key_2); + +-- name: DeleteZombieChannel :execresult +DELETE FROM zombie_channels +WHERE scid = $1 +AND version = $2; + +-- name: CountZombieChannels :one +SELECT COUNT(*) +FROM zombie_channels +WHERE version = $1; + +-- name: GetZombieChannel :one +SELECT * +FROM zombie_channels +WHERE scid = $1 +AND version = $2;