graph/db+sqldb: implement FetchChannelEdgesByOutpoint/SCID

And run `TestEdgeInsertionDeletion` against our SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 17:14:41 +02:00
parent 4335d9cfb7
commit 4fad4a7023
5 changed files with 412 additions and 2 deletions

View File

@@ -392,7 +392,7 @@ func TestEdgeInsertionDeletion(t *testing.T) {
t.Parallel()
ctx := context.Background()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
// We'd like to test the insertion/deletion of edges, so we create two
// vertexes to connect.

View File

@@ -95,6 +95,7 @@ type SQLQueries interface {
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
DeleteChannel(ctx context.Context, id int64) error
@@ -118,6 +119,7 @@ type SQLQueries interface {
GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
CountZombieChannels(ctx context.Context, version int16) (int64, error)
DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
}
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
@@ -1671,6 +1673,166 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
return deleted, nil
}
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
var (
ctx = context.TODO()
edge *models.ChannelEdgeInfo
policy1, policy2 *models.ChannelEdgePolicy
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
// First check if this edge is perhaps in the zombie
// index.
isZombie, err := db.IsZombieChannel(
ctx, sqlc.IsZombieChannelParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if err != nil {
return fmt.Errorf("unable to check if "+
"channel is zombie: %w", err)
} else if isZombie {
return ErrZombieEdge
}
return ErrEdgeNotFound
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w", err)
}
node1, node2, err := buildNodeVertices(
row.Node.PubKey, row.Node_2.PubKey,
)
if err != nil {
return err
}
edge, err = getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel info: %w",
err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return fmt.Errorf("unable to extract channel "+
"policies: %w", err)
}
policy1, policy2, err = getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel "+
"policies: %w", err)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
err)
}
return edge, policy1, policy2, nil
}
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
var (
ctx = context.TODO()
edge *models.ChannelEdgeInfo
policy1, policy2 *models.ChannelEdgePolicy
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
row, err := db.GetChannelByOutpointWithPolicies(
ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
Outpoint: op.String(),
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
return ErrEdgeNotFound
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w", err)
}
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return err
}
edge, err = getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel info: %w",
err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return fmt.Errorf("unable to extract channel "+
"policies: %w", err)
}
policy1, policy2, err = getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel "+
"policies: %w", err)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
err)
}
return edge, policy1, policy2, nil
}
// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
@@ -3066,12 +3228,52 @@ func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
//nolint:ll,dupl
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
error) {
var policy1, policy2 *sqlc.ChannelPolicy
switch r := row.(type) {
case sqlc.GetChannelBySCIDWithPoliciesRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.ChannelPolicy{
ID: r.Policy1ID.Int64,
Version: r.Policy1Version.Int16,
ChannelID: r.Channel.ID,
NodeID: r.Policy1NodeID.Int64,
Timelock: r.Policy1Timelock.Int32,
FeePpm: r.Policy1FeePpm.Int64,
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
LastUpdate: r.Policy1LastUpdate,
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
Disabled: r.Policy1Disabled,
Signature: r.Policy1Signature,
}
}
if r.Policy2ID.Valid {
policy2 = &sqlc.ChannelPolicy{
ID: r.Policy2ID.Int64,
Version: r.Policy2Version.Int16,
ChannelID: r.Channel.ID,
NodeID: r.Policy2NodeID.Int64,
Timelock: r.Policy2Timelock.Int32,
FeePpm: r.Policy2FeePpm.Int64,
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
LastUpdate: r.Policy2LastUpdate,
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
Disabled: r.Policy2Disabled,
Signature: r.Policy2Signature,
}
}
return policy1, policy2, nil
case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.ChannelPolicy{

View File

@@ -263,6 +263,138 @@ func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAn
return i, err
}
const getChannelByOutpointWithPolicies = `-- name: GetChannelByOutpointWithPolicies :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey,
-- Node 1 policy
cp1.id AS policy_1_id,
cp1.node_id AS policy_1_node_id,
cp1.version AS policy_1_version,
cp1.timelock AS policy_1_timelock,
cp1.fee_ppm AS policy_1_fee_ppm,
cp1.base_fee_msat AS policy_1_base_fee_msat,
cp1.min_htlc_msat AS policy_1_min_htlc_msat,
cp1.max_htlc_msat AS policy_1_max_htlc_msat,
cp1.last_update AS policy_1_last_update,
cp1.disabled AS policy_1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.signature AS policy_1_signature,
-- Node 2 policy
cp2.id AS policy_2_id,
cp2.node_id AS policy_2_node_id,
cp2.version AS policy_2_version,
cp2.timelock AS policy_2_timelock,
cp2.fee_ppm AS policy_2_fee_ppm,
cp2.base_fee_msat AS policy_2_base_fee_msat,
cp2.min_htlc_msat AS policy_2_min_htlc_msat,
cp2.max_htlc_msat AS policy_2_max_htlc_msat,
cp2.last_update AS policy_2_last_update,
cp2.disabled AS policy_2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.signature AS policy_2_signature
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE c.outpoint = $1 AND c.version = $2
`
type GetChannelByOutpointWithPoliciesParams struct {
Outpoint string
Version int16
}
type GetChannelByOutpointWithPoliciesRow struct {
Channel Channel
Node1Pubkey []byte
Node2Pubkey []byte
Policy1ID sql.NullInt64
Policy1NodeID sql.NullInt64
Policy1Version sql.NullInt16
Policy1Timelock sql.NullInt32
Policy1FeePpm sql.NullInt64
Policy1BaseFeeMsat sql.NullInt64
Policy1MinHtlcMsat sql.NullInt64
Policy1MaxHtlcMsat sql.NullInt64
Policy1LastUpdate sql.NullInt64
Policy1Disabled sql.NullBool
Policy1InboundBaseFeeMsat sql.NullInt64
Policy1InboundFeeRateMilliMsat sql.NullInt64
Policy1Signature []byte
Policy2ID sql.NullInt64
Policy2NodeID sql.NullInt64
Policy2Version sql.NullInt16
Policy2Timelock sql.NullInt32
Policy2FeePpm sql.NullInt64
Policy2BaseFeeMsat sql.NullInt64
Policy2MinHtlcMsat sql.NullInt64
Policy2MaxHtlcMsat sql.NullInt64
Policy2LastUpdate sql.NullInt64
Policy2Disabled sql.NullBool
Policy2InboundBaseFeeMsat sql.NullInt64
Policy2InboundFeeRateMilliMsat sql.NullInt64
Policy2Signature []byte
}
func (q *Queries) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) {
row := q.db.QueryRowContext(ctx, getChannelByOutpointWithPolicies, arg.Outpoint, arg.Version)
var i GetChannelByOutpointWithPoliciesRow
err := row.Scan(
&i.Channel.ID,
&i.Channel.Version,
&i.Channel.Scid,
&i.Channel.NodeID1,
&i.Channel.NodeID2,
&i.Channel.Outpoint,
&i.Channel.Capacity,
&i.Channel.BitcoinKey1,
&i.Channel.BitcoinKey2,
&i.Channel.Node1Signature,
&i.Channel.Node2Signature,
&i.Channel.Bitcoin1Signature,
&i.Channel.Bitcoin2Signature,
&i.Node1Pubkey,
&i.Node2Pubkey,
&i.Policy1ID,
&i.Policy1NodeID,
&i.Policy1Version,
&i.Policy1Timelock,
&i.Policy1FeePpm,
&i.Policy1BaseFeeMsat,
&i.Policy1MinHtlcMsat,
&i.Policy1MaxHtlcMsat,
&i.Policy1LastUpdate,
&i.Policy1Disabled,
&i.Policy1InboundBaseFeeMsat,
&i.Policy1InboundFeeRateMilliMsat,
&i.Policy1Signature,
&i.Policy2ID,
&i.Policy2NodeID,
&i.Policy2Version,
&i.Policy2Timelock,
&i.Policy2FeePpm,
&i.Policy2BaseFeeMsat,
&i.Policy2MinHtlcMsat,
&i.Policy2MaxHtlcMsat,
&i.Policy2LastUpdate,
&i.Policy2Disabled,
&i.Policy2InboundBaseFeeMsat,
&i.Policy2InboundFeeRateMilliMsat,
&i.Policy2Signature,
)
return i, err
}
const getChannelBySCID = `-- name: GetChannelBySCID :one
SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM channels
WHERE scid = $1 AND version = $2
@@ -1211,6 +1343,27 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa
return err
}
const isZombieChannel = `-- name: IsZombieChannel :one
SELECT EXISTS (
SELECT 1
FROM zombie_channels
WHERE scid = $1
AND version = $2
) AS is_zombie
`
type IsZombieChannelParams struct {
Scid []byte
Version int16
}
func (q *Queries) IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error) {
row := q.db.QueryRowContext(ctx, isZombieChannel, arg.Scid, arg.Version)
var is_zombie bool
err := row.Scan(&is_zombie)
return is_zombie, err
}
const listChannelsByNodeID = `-- name: ListChannelsByNodeID :many
SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pubkey,

View File

@@ -31,6 +31,7 @@ type Querier interface {
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error)
GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error)
@@ -72,6 +73,7 @@ type Querier interface {
InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error)
InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error
InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error
IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error)
ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error)
ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error)

View File

@@ -262,6 +262,51 @@ ORDER BY
ELSE COALESCE(cp2.last_update, 0)
END ASC;
-- name: GetChannelByOutpointWithPolicies :one
SELECT
sqlc.embed(c),
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey,
-- Node 1 policy
cp1.id AS policy_1_id,
cp1.node_id AS policy_1_node_id,
cp1.version AS policy_1_version,
cp1.timelock AS policy_1_timelock,
cp1.fee_ppm AS policy_1_fee_ppm,
cp1.base_fee_msat AS policy_1_base_fee_msat,
cp1.min_htlc_msat AS policy_1_min_htlc_msat,
cp1.max_htlc_msat AS policy_1_max_htlc_msat,
cp1.last_update AS policy_1_last_update,
cp1.disabled AS policy_1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.signature AS policy_1_signature,
-- Node 2 policy
cp2.id AS policy_2_id,
cp2.node_id AS policy_2_node_id,
cp2.version AS policy_2_version,
cp2.timelock AS policy_2_timelock,
cp2.fee_ppm AS policy_2_fee_ppm,
cp2.base_fee_msat AS policy_2_base_fee_msat,
cp2.min_htlc_msat AS policy_2_min_htlc_msat,
cp2.max_htlc_msat AS policy_2_max_htlc_msat,
cp2.last_update AS policy_2_last_update,
cp2.disabled AS policy_2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.signature AS policy_2_signature
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE c.outpoint = $1 AND c.version = $2;
-- name: HighestSCID :one
SELECT scid
FROM channels
@@ -540,3 +585,11 @@ SELECT *
FROM zombie_channels
WHERE scid = $1
AND version = $2;
-- name: IsZombieChannel :one
SELECT EXISTS (
SELECT 1
FROM zombie_channels
WHERE scid = $1
AND version = $2
) AS is_zombie;