From 13bf6a549fb8f9d4a2f6b51e5da36657c32eda94 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 11 Jun 2025 17:20:44 +0200 Subject: [PATCH] graph/db+sqldb: implement HasChannelEdge and ChannelID And run `TestEdgeInfoUpdates` against our SQL backends. --- graph/db/graph_test.go | 2 +- graph/db/sql_store.go | 197 +++++++++++++++++++++++++++++++++++ sqldb/sqlc/graph.sql.go | 17 +++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 4 + 5 files changed, 220 insertions(+), 1 deletion(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index b85c724c7..2650a2d6a 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -812,7 +812,7 @@ func TestEdgeInfoUpdates(t *testing.T) { t.Parallel() ctx := context.Background() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index f825b49a6..d41239315 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -97,6 +97,7 @@ type SQLQueries interface { GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error) + GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error) DeleteChannel(ctx context.Context, id int64) error CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error @@ -1833,6 +1834,162 @@ func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( return edge, policy1, policy2, nil } +// HasChannelEdge returns true if the database knows of a channel edge with the +// passed channel ID, and false otherwise. If an edge with that ID is found +// within the graph, then two time stamps representing the last time the edge +// was updated for both directed edges are returned along with the boolean. If +// it is not found, then the zombie index is checked and its result is returned +// as the second boolean. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, + bool, error) { + + ctx := context.TODO() + + var ( + exists bool + isZombie bool + node1LastUpdate time.Time + node2LastUpdate time.Time + ) + + // We'll query the cache with the shared lock held to allow multiple + // readers to access values in the cache concurrently if they exist. + s.cacheMu.RLock() + if entry, ok := s.rejectCache.get(chanID); ok { + s.cacheMu.RUnlock() + node1LastUpdate = time.Unix(entry.upd1Time, 0) + node2LastUpdate = time.Unix(entry.upd2Time, 0) + exists, isZombie = entry.flags.unpack() + + return node1LastUpdate, node2LastUpdate, exists, isZombie, nil + } + s.cacheMu.RUnlock() + + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + + // The item was not found with the shared lock, so we'll acquire the + // exclusive lock and check the cache again in case another method added + // the entry to the cache while no lock was held. + if entry, ok := s.rejectCache.get(chanID); ok { + node1LastUpdate = time.Unix(entry.upd1Time, 0) + node2LastUpdate = time.Unix(entry.upd2Time, 0) + exists, isZombie = entry.flags.unpack() + + return node1LastUpdate, node2LastUpdate, exists, isZombie, nil + } + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + var chanIDB [8]byte + byteOrder.PutUint64(chanIDB[:], chanID) + + channel, err := db.GetChannelBySCID( + ctx, sqlc.GetChannelBySCIDParams{ + Scid: chanIDB[:], + Version: int16(ProtocolV1), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + // Check if it is a zombie channel. + isZombie, err = db.IsZombieChannel( + ctx, sqlc.IsZombieChannelParams{ + Scid: chanIDB[:], + Version: int16(ProtocolV1), + }, + ) + if err != nil { + return fmt.Errorf("could not check if channel "+ + "is zombie: %w", err) + } + + return nil + } else if err != nil { + return fmt.Errorf("unable to fetch channel: %w", err) + } + + exists = true + + policy1, err := db.GetChannelPolicyByChannelAndNode( + ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{ + Version: int16(ProtocolV1), + ChannelID: channel.ID, + NodeID: channel.NodeID1, + }, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("unable to fetch channel policy: %w", + err) + } else if err == nil { + node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0) + } + + policy2, err := db.GetChannelPolicyByChannelAndNode( + ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{ + Version: int16(ProtocolV1), + ChannelID: channel.ID, + NodeID: channel.NodeID2, + }, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("unable to fetch channel policy: %w", + err) + } else if err == nil { + node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0) + } + + return nil + }, sqldb.NoOpReset) + if err != nil { + return time.Time{}, time.Time{}, false, false, + fmt.Errorf("unable to fetch channel: %w", err) + } + + s.rejectCache.insert(chanID, rejectCacheEntry{ + upd1Time: node1LastUpdate.Unix(), + upd2Time: node2LastUpdate.Unix(), + flags: packRejectFlags(exists, isZombie), + }) + + return node1LastUpdate, node2LastUpdate, exists, isZombie, nil +} + +// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the +// passed channel point (outpoint). If the passed channel doesn't exist within +// the database, then ErrEdgeNotFound is returned. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { + var ( + ctx = context.TODO() + channelID uint64 + ) + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + chanID, err := db.GetSCIDByOutpoint( + ctx, sqlc.GetSCIDByOutpointParams{ + Outpoint: chanPoint.String(), + Version: int16(ProtocolV1), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return ErrEdgeNotFound + } else if err != nil { + return fmt.Errorf("unable to fetch channel ID: %w", + err) + } + + channelID = byteOrder.Uint64(chanID) + + return nil + }, sqldb.NoOpReset) + if err != nil { + return 0, fmt.Errorf("unable to fetch channel ID: %w", err) + } + + return channelID, nil +} + // forEachNodeDirectedChannel iterates through all channels of a given // node, executing the passed callback on the directed edge representing the // channel and its incoming policy. If the node is not found, no error is @@ -3234,6 +3391,46 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy, var policy1, policy2 *sqlc.ChannelPolicy switch r := row.(type) { + case sqlc.GetChannelByOutpointWithPoliciesRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.ChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.ChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.Channel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + case sqlc.GetChannelBySCIDWithPoliciesRow: if r.Policy1ID.Valid { policy1 = &sqlc.ChannelPolicy{ diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 81a52e2a2..77771052e 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1167,6 +1167,23 @@ func (q *Queries) GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1 return items, nil } +const getSCIDByOutpoint = `-- name: GetSCIDByOutpoint :one +SELECT scid from channels +WHERE outpoint = $1 AND version = $2 +` + +type GetSCIDByOutpointParams struct { + Outpoint string + Version int16 +} + +func (q *Queries) GetSCIDByOutpoint(ctx context.Context, arg GetSCIDByOutpointParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getSCIDByOutpoint, arg.Outpoint, arg.Version) + var scid []byte + err := row.Scan(&scid) + return scid, err +} + const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many SELECT sn.node_id, n.pub_key FROM source_nodes sn diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 2cb044285..da417d5cd 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -58,6 +58,7 @@ type Querier interface { GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error) + GetSCIDByOutpoint(ctx context.Context, arg GetSCIDByOutpointParams) ([]byte, error) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 3ecec7676..af11fb515 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -206,6 +206,10 @@ SELECT FROM channel_extra_types cet WHERE cet.channel_id = $1; +-- name: GetSCIDByOutpoint :one +SELECT scid from channels +WHERE outpoint = $1 AND version = $2; + -- name: GetChannelsByPolicyLastUpdateRange :many SELECT sqlc.embed(c),