graph/db+sqldb: implement UpdateEdgePolicy

In this commit, the various SQL queries are defined that we will need in
order to implement the SQLStore UpdateEdgePolicy method. Channel
policies can be "replaced" and so we use the upsert pattern for them
with the rule that any new channel policy must have a timestamp greater
than the previous one we persisted.

As is done for the KVStore implementation of the method, we use the
batch scheduler for this method.
This commit is contained in:
Elle Mouton
2025-06-11 14:25:39 +02:00
parent 498a18d028
commit c327988bb3
5 changed files with 451 additions and 1 deletions

View File

@ -4038,7 +4038,7 @@ func TestBatchedAddChannelEdge(t *testing.T) {
func TestBatchedUpdateEdgePolicy(t *testing.T) {
t.Parallel()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
// We'd like to test the update of edges inserted into the database, so
// we create two vertexes to connect.

View File

@ -76,10 +76,19 @@ type SQLQueries interface {
*/
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
/*
Channel Policy table queries.
*/
UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
}
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
@ -552,6 +561,193 @@ func (s *SQLStore) HighestChanID() (uint64, error) {
return highestChanID, nil
}
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
ctx := context.TODO()
var (
isUpdate1 bool
edgeNotFound bool
from, to route.Vertex
)
r := &batch.Request[SQLQueries]{
Opts: batch.NewSchedulerOptions(opts...),
Reset: func() {
isUpdate1 = false
edgeNotFound = false
},
Do: func(tx SQLQueries) error {
var err error
from, to, isUpdate1, err = updateChanEdgePolicy(
ctx, tx, edge,
)
if err != nil {
log.Errorf("UpdateEdgePolicy faild: %v", err)
}
// Silence ErrEdgeNotFound so that the batch can
// succeed, but propagate the error via local state.
if errors.Is(err, ErrEdgeNotFound) {
edgeNotFound = true
return nil
}
return err
},
OnCommit: func(err error) error {
switch {
case err != nil:
return err
case edgeNotFound:
return ErrEdgeNotFound
default:
s.updateEdgeCache(edge, isUpdate1)
return nil
}
},
}
err := s.chanScheduler.Execute(ctx, r)
return from, to, err
}
// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
isUpdate1 bool) {
// If an entry for this channel is found in reject cache, we'll modify
// the entry with the updated timestamp for the direction that was just
// written. If the edge doesn't exist, we'll load the cache entry lazily
// during the next query for this edge.
if entry, ok := s.rejectCache.get(e.ChannelID); ok {
if isUpdate1 {
entry.upd1Time = e.LastUpdate.Unix()
} else {
entry.upd2Time = e.LastUpdate.Unix()
}
s.rejectCache.insert(e.ChannelID, entry)
}
// If an entry for this channel is found in channel cache, we'll modify
// the entry with the updated policy for the direction that was just
// written. If the edge doesn't exist, we'll defer loading the info and
// policies and lazily read from disk during the next query.
if channel, ok := s.chanCache.get(e.ChannelID); ok {
if isUpdate1 {
channel.Policy1 = e
} else {
channel.Policy2 = e
}
s.chanCache.insert(e.ChannelID, channel)
}
}
// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
error) {
var (
node1Pub, node2Pub route.Vertex
isNode1 bool
chanIDB [8]byte
)
byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
// Check that this edge policy refers to a channel that we already
// know of. We do this explicitly so that we can return the appropriate
// ErrEdgeNotFound error if the channel doesn't exist, rather than
// abort the transaction which would abort the entire batch.
dbChan, err := tx.GetChannelAndNodesBySCID(
ctx, sqlc.GetChannelAndNodesBySCIDParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
return node1Pub, node2Pub, false, ErrEdgeNotFound
} else if err != nil {
return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
"fetch channel(%v): %w", edge.ChannelID, err)
}
copy(node1Pub[:], dbChan.Node1PubKey)
copy(node2Pub[:], dbChan.Node2PubKey)
// Figure out which node this edge is from.
isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
nodeID := dbChan.NodeID1
if !isNode1 {
nodeID = dbChan.NodeID2
}
var (
inboundBase sql.NullInt64
inboundRate sql.NullInt64
)
edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
inboundRate = sqldb.SQLInt64(fee.FeeRate)
inboundBase = sqldb.SQLInt64(fee.BaseFee)
})
id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
Version: int16(ProtocolV1),
ChannelID: dbChan.ID,
NodeID: nodeID,
Timelock: int32(edge.TimeLockDelta),
FeePpm: int64(edge.FeeProportionalMillionths),
BaseFeeMsat: int64(edge.FeeBaseMSat),
MinHtlcMsat: int64(edge.MinHTLC),
LastUpdate: sqldb.SQLInt64(edge.LastUpdate.Unix()),
Disabled: sql.NullBool{
Valid: true,
Bool: edge.IsDisabled(),
},
MaxHtlcMsat: sql.NullInt64{
Valid: edge.MessageFlags.HasMaxHtlc(),
Int64: int64(edge.MaxHTLC),
},
InboundBaseFeeMsat: inboundBase,
InboundFeeRateMilliMsat: inboundRate,
Signature: edge.SigBytes,
})
if err != nil {
return node1Pub, node2Pub, isNode1,
fmt.Errorf("unable to upsert edge policy: %w", err)
}
// Convert the flat extra opaque data into a map of TLV types to
// values.
extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
if err != nil {
return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
"marshal extra opaque data: %w", err)
}
// Update the channel policy's extra signed fields.
err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
if err != nil {
return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
"policy extra TLVs: %w", err)
}
return node1Pub, node2Pub, isNode1, nil
}
// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, db SQLQueries,
pubKey route.Vertex) (int64, *models.LightningNode, error) {
@ -1267,3 +1463,36 @@ func maybeCreateShellNode(ctx context.Context, db SQLQueries,
return id, nil
}
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
// the database. This includes deleting any existing types and then inserting
// the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
chanPolicyID int64, extraFields map[uint64][]byte) error {
// Delete all existing extra signed fields for the channel policy.
err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
if err != nil {
return fmt.Errorf("unable to delete "+
"existing policy extra signed fields for policy %d: %w",
chanPolicyID, err)
}
// Insert all new extra signed fields for the channel policy.
for tlvType, value := range extraFields {
err = db.InsertChanPolicyExtraType(
ctx, sqlc.InsertChanPolicyExtraTypeParams{
ChannelPolicyID: chanPolicyID,
Type: int64(tlvType),
Value: value,
},
)
if err != nil {
return fmt.Errorf("unable to insert "+
"channel_policy(%d) extra signed field(%v): %w",
chanPolicyID, tlvType, err)
}
}
return nil
}

View File

@ -101,6 +101,16 @@ func (q *Queries) CreateChannelExtraType(ctx context.Context, arg CreateChannelE
return err
}
const deleteChannelPolicyExtraTypes = `-- name: DeleteChannelPolicyExtraTypes :exec
DELETE FROM channel_policy_extra_types
WHERE channel_policy_id = $1
`
func (q *Queries) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error {
_, err := q.db.ExecContext(ctx, deleteChannelPolicyExtraTypes, channelPolicyID)
return err
}
const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec
DELETE FROM node_extra_types
WHERE node_id = $1
@ -158,6 +168,64 @@ func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeaturePa
return err
}
const getChannelAndNodesBySCID = `-- name: GetChannelAndNodesBySCID :one
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pub_key,
n2.pub_key AS node2_pub_key
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
WHERE c.scid = $1
AND c.version = $2
`
type GetChannelAndNodesBySCIDParams struct {
Scid []byte
Version int16
}
type GetChannelAndNodesBySCIDRow struct {
ID int64
Version int16
Scid []byte
NodeID1 int64
NodeID2 int64
Outpoint string
Capacity sql.NullInt64
BitcoinKey1 []byte
BitcoinKey2 []byte
Node1Signature []byte
Node2Signature []byte
Bitcoin1Signature []byte
Bitcoin2Signature []byte
Node1PubKey []byte
Node2PubKey []byte
}
func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) {
row := q.db.QueryRowContext(ctx, getChannelAndNodesBySCID, arg.Scid, arg.Version)
var i GetChannelAndNodesBySCIDRow
err := row.Scan(
&i.ID,
&i.Version,
&i.Scid,
&i.NodeID1,
&i.NodeID2,
&i.Outpoint,
&i.Capacity,
&i.BitcoinKey1,
&i.BitcoinKey2,
&i.Node1Signature,
&i.Node2Signature,
&i.Bitcoin1Signature,
&i.Bitcoin2Signature,
&i.Node1PubKey,
&i.Node2PubKey,
)
return i, err
}
const getChannelBySCID = `-- name: GetChannelBySCID :one
SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM channels
WHERE scid = $1 AND version = $2
@ -444,6 +512,29 @@ func (q *Queries) HighestSCID(ctx context.Context, version int16) ([]byte, error
return scid, err
}
const insertChanPolicyExtraType = `-- name: InsertChanPolicyExtraType :exec
/* ─────────────────────────────────────────────
channel_policy_extra_types table queries
─────────────────────────────────────────────
*/
INSERT INTO channel_policy_extra_types (
channel_policy_id, type, value
)
VALUES ($1, $2, $3)
`
type InsertChanPolicyExtraTypeParams struct {
ChannelPolicyID int64
Type int64
Value []byte
}
func (q *Queries) InsertChanPolicyExtraType(ctx context.Context, arg InsertChanPolicyExtraTypeParams) error {
_, err := q.db.ExecContext(ctx, insertChanPolicyExtraType, arg.ChannelPolicyID, arg.Type, arg.Value)
return err
}
const insertChannelFeature = `-- name: InsertChannelFeature :exec
/* ─────────────────────────────────────────────
channel_features table queries
@ -523,6 +614,75 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa
return err
}
const upsertEdgePolicy = `-- name: UpsertEdgePolicy :one
/* ─────────────────────────────────────────────
channel_policies table queries
─────────────────────────────────────────────
*/
INSERT INTO channel_policies (
version, channel_id, node_id, timelock, fee_ppm,
base_fee_msat, min_htlc_msat, last_update, disabled,
max_htlc_msat, inbound_base_fee_msat,
inbound_fee_rate_milli_msat, signature
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
ON CONFLICT (channel_id, node_id, version)
-- Update the following fields if a conflict occurs on channel_id,
-- node_id, and version.
DO UPDATE SET
timelock = EXCLUDED.timelock,
fee_ppm = EXCLUDED.fee_ppm,
base_fee_msat = EXCLUDED.base_fee_msat,
min_htlc_msat = EXCLUDED.min_htlc_msat,
last_update = EXCLUDED.last_update,
disabled = EXCLUDED.disabled,
max_htlc_msat = EXCLUDED.max_htlc_msat,
inbound_base_fee_msat = EXCLUDED.inbound_base_fee_msat,
inbound_fee_rate_milli_msat = EXCLUDED.inbound_fee_rate_milli_msat,
signature = EXCLUDED.signature
WHERE EXCLUDED.last_update > channel_policies.last_update
RETURNING id
`
type UpsertEdgePolicyParams struct {
Version int16
ChannelID int64
NodeID int64
Timelock int32
FeePpm int64
BaseFeeMsat int64
MinHtlcMsat int64
LastUpdate sql.NullInt64
Disabled sql.NullBool
MaxHtlcMsat sql.NullInt64
InboundBaseFeeMsat sql.NullInt64
InboundFeeRateMilliMsat sql.NullInt64
Signature []byte
}
func (q *Queries) UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error) {
row := q.db.QueryRowContext(ctx, upsertEdgePolicy,
arg.Version,
arg.ChannelID,
arg.NodeID,
arg.Timelock,
arg.FeePpm,
arg.BaseFeeMsat,
arg.MinHtlcMsat,
arg.LastUpdate,
arg.Disabled,
arg.MaxHtlcMsat,
arg.InboundBaseFeeMsat,
arg.InboundFeeRateMilliMsat,
arg.Signature,
)
var id int64
err := row.Scan(&id)
return id, err
}
const upsertNode = `-- name: UpsertNode :one
/* ─────────────────────────────────────────────
nodes table queries

View File

@ -16,6 +16,7 @@ type Querier interface {
CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error)
CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error
DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error)
DeleteNodeAddresses(ctx context.Context, nodeID int64) error
@ -26,6 +27,7 @@ type Querier interface {
FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error)
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error)
GetDatabaseVersion(ctx context.Context) (int32, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error)
@ -49,6 +51,7 @@ type Querier interface {
HighestSCID(ctx context.Context, version int16) ([]byte, error)
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
InsertChanPolicyExtraType(ctx context.Context, arg InsertChanPolicyExtraTypeParams) error
InsertChannelFeature(ctx context.Context, arg InsertChannelFeatureParams) error
InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error)
InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error
@ -74,6 +77,7 @@ type Querier interface {
UpdateInvoiceHTLCs(ctx context.Context, arg UpdateInvoiceHTLCsParams) error
UpdateInvoiceState(ctx context.Context, arg UpdateInvoiceStateParams) (sql.Result, error)
UpsertAMPSubInvoice(ctx context.Context, arg UpsertAMPSubInvoiceParams) (sql.Result, error)
UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error)
UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error)
UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error
}

View File

@ -154,6 +154,17 @@ RETURNING id;
SELECT * FROM channels
WHERE scid = $1 AND version = $2;
-- name: GetChannelAndNodesBySCID :one
SELECT
c.*,
n1.pub_key AS node1_pub_key,
n2.pub_key AS node2_pub_key
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
WHERE c.scid = $1
AND c.version = $2;
-- name: HighestSCID :one
SELECT scid
FROM channels
@ -183,3 +194,49 @@ INSERT INTO channel_extra_types (
channel_id, type, value
)
VALUES ($1, $2, $3);
/* ─────────────────────────────────────────────
channel_policies table queries
─────────────────────────────────────────────
*/
-- name: UpsertEdgePolicy :one
INSERT INTO channel_policies (
version, channel_id, node_id, timelock, fee_ppm,
base_fee_msat, min_htlc_msat, last_update, disabled,
max_htlc_msat, inbound_base_fee_msat,
inbound_fee_rate_milli_msat, signature
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
ON CONFLICT (channel_id, node_id, version)
-- Update the following fields if a conflict occurs on channel_id,
-- node_id, and version.
DO UPDATE SET
timelock = EXCLUDED.timelock,
fee_ppm = EXCLUDED.fee_ppm,
base_fee_msat = EXCLUDED.base_fee_msat,
min_htlc_msat = EXCLUDED.min_htlc_msat,
last_update = EXCLUDED.last_update,
disabled = EXCLUDED.disabled,
max_htlc_msat = EXCLUDED.max_htlc_msat,
inbound_base_fee_msat = EXCLUDED.inbound_base_fee_msat,
inbound_fee_rate_milli_msat = EXCLUDED.inbound_fee_rate_milli_msat,
signature = EXCLUDED.signature
WHERE EXCLUDED.last_update > channel_policies.last_update
RETURNING id;
/* ─────────────────────────────────────────────
channel_policy_extra_types table queries
─────────────────────────────────────────────
*/
-- name: InsertChanPolicyExtraType :exec
INSERT INTO channel_policy_extra_types (
channel_policy_id, type, value
)
VALUES ($1, $2, $3);
-- name: DeleteChannelPolicyExtraTypes :exec
DELETE FROM channel_policy_extra_types
WHERE channel_policy_id = $1;