diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index efe1b4642..18d3e52a5 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -4038,7 +4038,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { func TestBatchedUpdateEdgePolicy(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index bc1a92e76..2ae133480 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -76,10 +76,19 @@ type SQLQueries interface { */ CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error) + GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error + + /* + Channel Policy table queries. + */ + UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error) + + InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error + DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error } // BatchedSQLQueries is a version of SQLQueries that's capable of batched @@ -552,6 +561,193 @@ func (s *SQLStore) HighestChanID() (uint64, error) { return highestChanID, nil } +// UpdateEdgePolicy updates the edge routing policy for a single directed edge +// within the database for the referenced channel. The `flags` attribute within +// the ChannelEdgePolicy determines which of the directed edges are being +// updated. If the flag is 1, then the first node's information is being +// updated, otherwise it's the second node's information. The node ordering is +// determined by the lexicographical ordering of the identity public keys of the +// nodes on either side of the channel. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, + opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) { + + ctx := context.TODO() + + var ( + isUpdate1 bool + edgeNotFound bool + from, to route.Vertex + ) + + r := &batch.Request[SQLQueries]{ + Opts: batch.NewSchedulerOptions(opts...), + Reset: func() { + isUpdate1 = false + edgeNotFound = false + }, + Do: func(tx SQLQueries) error { + var err error + from, to, isUpdate1, err = updateChanEdgePolicy( + ctx, tx, edge, + ) + if err != nil { + log.Errorf("UpdateEdgePolicy faild: %v", err) + } + + // Silence ErrEdgeNotFound so that the batch can + // succeed, but propagate the error via local state. + if errors.Is(err, ErrEdgeNotFound) { + edgeNotFound = true + return nil + } + + return err + }, + OnCommit: func(err error) error { + switch { + case err != nil: + return err + case edgeNotFound: + return ErrEdgeNotFound + default: + s.updateEdgeCache(edge, isUpdate1) + return nil + } + }, + } + + err := s.chanScheduler.Execute(ctx, r) + + return from, to, err +} + +// updateEdgeCache updates our reject and channel caches with the new +// edge policy information. +func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy, + isUpdate1 bool) { + + // If an entry for this channel is found in reject cache, we'll modify + // the entry with the updated timestamp for the direction that was just + // written. If the edge doesn't exist, we'll load the cache entry lazily + // during the next query for this edge. + if entry, ok := s.rejectCache.get(e.ChannelID); ok { + if isUpdate1 { + entry.upd1Time = e.LastUpdate.Unix() + } else { + entry.upd2Time = e.LastUpdate.Unix() + } + s.rejectCache.insert(e.ChannelID, entry) + } + + // If an entry for this channel is found in channel cache, we'll modify + // the entry with the updated policy for the direction that was just + // written. If the edge doesn't exist, we'll defer loading the info and + // policies and lazily read from disk during the next query. + if channel, ok := s.chanCache.get(e.ChannelID); ok { + if isUpdate1 { + channel.Policy1 = e + } else { + channel.Policy2 = e + } + s.chanCache.insert(e.ChannelID, channel) + } +} + +// updateChanEdgePolicy upserts the channel policy info we have stored for +// a channel we already know of. +func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, + edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool, + error) { + + var ( + node1Pub, node2Pub route.Vertex + isNode1 bool + chanIDB [8]byte + ) + byteOrder.PutUint64(chanIDB[:], edge.ChannelID) + + // Check that this edge policy refers to a channel that we already + // know of. We do this explicitly so that we can return the appropriate + // ErrEdgeNotFound error if the channel doesn't exist, rather than + // abort the transaction which would abort the entire batch. + dbChan, err := tx.GetChannelAndNodesBySCID( + ctx, sqlc.GetChannelAndNodesBySCIDParams{ + Scid: chanIDB[:], + Version: int16(ProtocolV1), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return node1Pub, node2Pub, false, ErrEdgeNotFound + } else if err != nil { + return node1Pub, node2Pub, false, fmt.Errorf("unable to "+ + "fetch channel(%v): %w", edge.ChannelID, err) + } + + copy(node1Pub[:], dbChan.Node1PubKey) + copy(node2Pub[:], dbChan.Node2PubKey) + + // Figure out which node this edge is from. + isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 + nodeID := dbChan.NodeID1 + if !isNode1 { + nodeID = dbChan.NodeID2 + } + + var ( + inboundBase sql.NullInt64 + inboundRate sql.NullInt64 + ) + edge.InboundFee.WhenSome(func(fee lnwire.Fee) { + inboundRate = sqldb.SQLInt64(fee.FeeRate) + inboundBase = sqldb.SQLInt64(fee.BaseFee) + }) + + id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{ + Version: int16(ProtocolV1), + ChannelID: dbChan.ID, + NodeID: nodeID, + Timelock: int32(edge.TimeLockDelta), + FeePpm: int64(edge.FeeProportionalMillionths), + BaseFeeMsat: int64(edge.FeeBaseMSat), + MinHtlcMsat: int64(edge.MinHTLC), + LastUpdate: sqldb.SQLInt64(edge.LastUpdate.Unix()), + Disabled: sql.NullBool{ + Valid: true, + Bool: edge.IsDisabled(), + }, + MaxHtlcMsat: sql.NullInt64{ + Valid: edge.MessageFlags.HasMaxHtlc(), + Int64: int64(edge.MaxHTLC), + }, + InboundBaseFeeMsat: inboundBase, + InboundFeeRateMilliMsat: inboundRate, + Signature: edge.SigBytes, + }) + if err != nil { + return node1Pub, node2Pub, isNode1, + fmt.Errorf("unable to upsert edge policy: %w", err) + } + + // Convert the flat extra opaque data into a map of TLV types to + // values. + extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData) + if err != nil { + return node1Pub, node2Pub, false, fmt.Errorf("unable to "+ + "marshal extra opaque data: %w", err) + } + + // Update the channel policy's extra signed fields. + err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra) + if err != nil { + return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+ + "policy extra TLVs: %w", err) + } + + return node1Pub, node2Pub, isNode1, nil +} + // getNodeByPubKey attempts to look up a target node by its public key. func getNodeByPubKey(ctx context.Context, db SQLQueries, pubKey route.Vertex) (int64, *models.LightningNode, error) { @@ -1267,3 +1463,36 @@ func maybeCreateShellNode(ctx context.Context, db SQLQueries, return id, nil } + +// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in +// the database. This includes deleting any existing types and then inserting +// the new types. +func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries, + chanPolicyID int64, extraFields map[uint64][]byte) error { + + // Delete all existing extra signed fields for the channel policy. + err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID) + if err != nil { + return fmt.Errorf("unable to delete "+ + "existing policy extra signed fields for policy %d: %w", + chanPolicyID, err) + } + + // Insert all new extra signed fields for the channel policy. + for tlvType, value := range extraFields { + err = db.InsertChanPolicyExtraType( + ctx, sqlc.InsertChanPolicyExtraTypeParams{ + ChannelPolicyID: chanPolicyID, + Type: int64(tlvType), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("unable to insert "+ + "channel_policy(%d) extra signed field(%v): %w", + chanPolicyID, tlvType, err) + } + } + + return nil +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 2c42e535a..82385008d 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -101,6 +101,16 @@ func (q *Queries) CreateChannelExtraType(ctx context.Context, arg CreateChannelE return err } +const deleteChannelPolicyExtraTypes = `-- name: DeleteChannelPolicyExtraTypes :exec +DELETE FROM channel_policy_extra_types +WHERE channel_policy_id = $1 +` + +func (q *Queries) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error { + _, err := q.db.ExecContext(ctx, deleteChannelPolicyExtraTypes, channelPolicyID) + return err +} + const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec DELETE FROM node_extra_types WHERE node_id = $1 @@ -158,6 +168,64 @@ func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeaturePa return err } +const getChannelAndNodesBySCID = `-- name: GetChannelAndNodesBySCID :one +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pub_key, + n2.pub_key AS node2_pub_key +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id +WHERE c.scid = $1 + AND c.version = $2 +` + +type GetChannelAndNodesBySCIDParams struct { + Scid []byte + Version int16 +} + +type GetChannelAndNodesBySCIDRow struct { + ID int64 + Version int16 + Scid []byte + NodeID1 int64 + NodeID2 int64 + Outpoint string + Capacity sql.NullInt64 + BitcoinKey1 []byte + BitcoinKey2 []byte + Node1Signature []byte + Node2Signature []byte + Bitcoin1Signature []byte + Bitcoin2Signature []byte + Node1PubKey []byte + Node2PubKey []byte +} + +func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) { + row := q.db.QueryRowContext(ctx, getChannelAndNodesBySCID, arg.Scid, arg.Version) + var i GetChannelAndNodesBySCIDRow + err := row.Scan( + &i.ID, + &i.Version, + &i.Scid, + &i.NodeID1, + &i.NodeID2, + &i.Outpoint, + &i.Capacity, + &i.BitcoinKey1, + &i.BitcoinKey2, + &i.Node1Signature, + &i.Node2Signature, + &i.Bitcoin1Signature, + &i.Bitcoin2Signature, + &i.Node1PubKey, + &i.Node2PubKey, + ) + return i, err +} + const getChannelBySCID = `-- name: GetChannelBySCID :one SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM channels WHERE scid = $1 AND version = $2 @@ -444,6 +512,29 @@ func (q *Queries) HighestSCID(ctx context.Context, version int16) ([]byte, error return scid, err } +const insertChanPolicyExtraType = `-- name: InsertChanPolicyExtraType :exec +/* ───────────────────────────────────────────── + channel_policy_extra_types table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO channel_policy_extra_types ( + channel_policy_id, type, value +) +VALUES ($1, $2, $3) +` + +type InsertChanPolicyExtraTypeParams struct { + ChannelPolicyID int64 + Type int64 + Value []byte +} + +func (q *Queries) InsertChanPolicyExtraType(ctx context.Context, arg InsertChanPolicyExtraTypeParams) error { + _, err := q.db.ExecContext(ctx, insertChanPolicyExtraType, arg.ChannelPolicyID, arg.Type, arg.Value) + return err +} + const insertChannelFeature = `-- name: InsertChannelFeature :exec /* ───────────────────────────────────────────── channel_features table queries @@ -523,6 +614,75 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa return err } +const upsertEdgePolicy = `-- name: UpsertEdgePolicy :one +/* ───────────────────────────────────────────── + channel_policies table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO channel_policies ( + version, channel_id, node_id, timelock, fee_ppm, + base_fee_msat, min_htlc_msat, last_update, disabled, + max_htlc_msat, inbound_base_fee_msat, + inbound_fee_rate_milli_msat, signature +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 +) +ON CONFLICT (channel_id, node_id, version) + -- Update the following fields if a conflict occurs on channel_id, + -- node_id, and version. + DO UPDATE SET + timelock = EXCLUDED.timelock, + fee_ppm = EXCLUDED.fee_ppm, + base_fee_msat = EXCLUDED.base_fee_msat, + min_htlc_msat = EXCLUDED.min_htlc_msat, + last_update = EXCLUDED.last_update, + disabled = EXCLUDED.disabled, + max_htlc_msat = EXCLUDED.max_htlc_msat, + inbound_base_fee_msat = EXCLUDED.inbound_base_fee_msat, + inbound_fee_rate_milli_msat = EXCLUDED.inbound_fee_rate_milli_msat, + signature = EXCLUDED.signature +WHERE EXCLUDED.last_update > channel_policies.last_update +RETURNING id +` + +type UpsertEdgePolicyParams struct { + Version int16 + ChannelID int64 + NodeID int64 + Timelock int32 + FeePpm int64 + BaseFeeMsat int64 + MinHtlcMsat int64 + LastUpdate sql.NullInt64 + Disabled sql.NullBool + MaxHtlcMsat sql.NullInt64 + InboundBaseFeeMsat sql.NullInt64 + InboundFeeRateMilliMsat sql.NullInt64 + Signature []byte +} + +func (q *Queries) UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error) { + row := q.db.QueryRowContext(ctx, upsertEdgePolicy, + arg.Version, + arg.ChannelID, + arg.NodeID, + arg.Timelock, + arg.FeePpm, + arg.BaseFeeMsat, + arg.MinHtlcMsat, + arg.LastUpdate, + arg.Disabled, + arg.MaxHtlcMsat, + arg.InboundBaseFeeMsat, + arg.InboundFeeRateMilliMsat, + arg.Signature, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + const upsertNode = `-- name: UpsertNode :one /* ───────────────────────────────────────────── nodes table queries diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index b905dd4c8..e7b225204 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -16,6 +16,7 @@ type Querier interface { CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) + DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) DeleteNodeAddresses(ctx context.Context, nodeID int64) error @@ -26,6 +27,7 @@ type Querier interface { FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) + GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) @@ -49,6 +51,7 @@ type Querier interface { HighestSCID(ctx context.Context, version int16) ([]byte, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error + InsertChanPolicyExtraType(ctx context.Context, arg InsertChanPolicyExtraTypeParams) error InsertChannelFeature(ctx context.Context, arg InsertChannelFeatureParams) error InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error @@ -74,6 +77,7 @@ type Querier interface { UpdateInvoiceHTLCs(ctx context.Context, arg UpdateInvoiceHTLCsParams) error UpdateInvoiceState(ctx context.Context, arg UpdateInvoiceStateParams) (sql.Result, error) UpsertAMPSubInvoice(ctx context.Context, arg UpsertAMPSubInvoiceParams) (sql.Result, error) + UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error) UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error } diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index a11bc7d62..85000ec8f 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -154,6 +154,17 @@ RETURNING id; SELECT * FROM channels WHERE scid = $1 AND version = $2; +-- name: GetChannelAndNodesBySCID :one +SELECT + c.*, + n1.pub_key AS node1_pub_key, + n2.pub_key AS node2_pub_key +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id +WHERE c.scid = $1 + AND c.version = $2; + -- name: HighestSCID :one SELECT scid FROM channels @@ -183,3 +194,49 @@ INSERT INTO channel_extra_types ( channel_id, type, value ) VALUES ($1, $2, $3); + +/* ───────────────────────────────────────────── + channel_policies table queries + ───────────────────────────────────────────── +*/ + +-- name: UpsertEdgePolicy :one +INSERT INTO channel_policies ( + version, channel_id, node_id, timelock, fee_ppm, + base_fee_msat, min_htlc_msat, last_update, disabled, + max_htlc_msat, inbound_base_fee_msat, + inbound_fee_rate_milli_msat, signature +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 +) +ON CONFLICT (channel_id, node_id, version) + -- Update the following fields if a conflict occurs on channel_id, + -- node_id, and version. + DO UPDATE SET + timelock = EXCLUDED.timelock, + fee_ppm = EXCLUDED.fee_ppm, + base_fee_msat = EXCLUDED.base_fee_msat, + min_htlc_msat = EXCLUDED.min_htlc_msat, + last_update = EXCLUDED.last_update, + disabled = EXCLUDED.disabled, + max_htlc_msat = EXCLUDED.max_htlc_msat, + inbound_base_fee_msat = EXCLUDED.inbound_base_fee_msat, + inbound_fee_rate_milli_msat = EXCLUDED.inbound_fee_rate_milli_msat, + signature = EXCLUDED.signature +WHERE EXCLUDED.last_update > channel_policies.last_update +RETURNING id; + +/* ───────────────────────────────────────────── + channel_policy_extra_types table queries + ───────────────────────────────────────────── +*/ + +-- name: InsertChanPolicyExtraType :exec +INSERT INTO channel_policy_extra_types ( + channel_policy_id, type, value +) +VALUES ($1, $2, $3); + +-- name: DeleteChannelPolicyExtraTypes :exec +DELETE FROM channel_policy_extra_types +WHERE channel_policy_id = $1;