graph/db+sqldb: implement HasChannelEdge and ChannelID

And run `TestEdgeInfoUpdates` against our SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 17:20:44 +02:00
parent 4fad4a7023
commit 13bf6a549f
5 changed files with 220 additions and 1 deletions

View File

@@ -97,6 +97,7 @@ type SQLQueries interface {
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
DeleteChannel(ctx context.Context, id int64) error
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
@@ -1833,6 +1834,162 @@ func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
return edge, policy1, policy2, nil
}
// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
bool, error) {
ctx := context.TODO()
var (
exists bool
isZombie bool
node1LastUpdate time.Time
node2LastUpdate time.Time
)
// We'll query the cache with the shared lock held to allow multiple
// readers to access values in the cache concurrently if they exist.
s.cacheMu.RLock()
if entry, ok := s.rejectCache.get(chanID); ok {
s.cacheMu.RUnlock()
node1LastUpdate = time.Unix(entry.upd1Time, 0)
node2LastUpdate = time.Unix(entry.upd2Time, 0)
exists, isZombie = entry.flags.unpack()
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
s.cacheMu.RUnlock()
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
// The item was not found with the shared lock, so we'll acquire the
// exclusive lock and check the cache again in case another method added
// the entry to the cache while no lock was held.
if entry, ok := s.rejectCache.get(chanID); ok {
node1LastUpdate = time.Unix(entry.upd1Time, 0)
node2LastUpdate = time.Unix(entry.upd2Time, 0)
exists, isZombie = entry.flags.unpack()
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
channel, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
// Check if it is a zombie channel.
isZombie, err = db.IsZombieChannel(
ctx, sqlc.IsZombieChannelParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if err != nil {
return fmt.Errorf("could not check if channel "+
"is zombie: %w", err)
}
return nil
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w", err)
}
exists = true
policy1, err := db.GetChannelPolicyByChannelAndNode(
ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
Version: int16(ProtocolV1),
ChannelID: channel.ID,
NodeID: channel.NodeID1,
},
)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("unable to fetch channel policy: %w",
err)
} else if err == nil {
node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
}
policy2, err := db.GetChannelPolicyByChannelAndNode(
ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
Version: int16(ProtocolV1),
ChannelID: channel.ID,
NodeID: channel.NodeID2,
},
)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("unable to fetch channel policy: %w",
err)
} else if err == nil {
node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return time.Time{}, time.Time{}, false, false,
fmt.Errorf("unable to fetch channel: %w", err)
}
s.rejectCache.insert(chanID, rejectCacheEntry{
upd1Time: node1LastUpdate.Unix(),
upd2Time: node2LastUpdate.Unix(),
flags: packRejectFlags(exists, isZombie),
})
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
// passed channel point (outpoint). If the passed channel doesn't exist within
// the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
var (
ctx = context.TODO()
channelID uint64
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
chanID, err := db.GetSCIDByOutpoint(
ctx, sqlc.GetSCIDByOutpointParams{
Outpoint: chanPoint.String(),
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
return ErrEdgeNotFound
} else if err != nil {
return fmt.Errorf("unable to fetch channel ID: %w",
err)
}
channelID = byteOrder.Uint64(chanID)
return nil
}, sqldb.NoOpReset)
if err != nil {
return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
}
return channelID, nil
}
// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
@@ -3234,6 +3391,46 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
var policy1, policy2 *sqlc.ChannelPolicy
switch r := row.(type) {
case sqlc.GetChannelByOutpointWithPoliciesRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.ChannelPolicy{
ID: r.Policy1ID.Int64,
Version: r.Policy1Version.Int16,
ChannelID: r.Channel.ID,
NodeID: r.Policy1NodeID.Int64,
Timelock: r.Policy1Timelock.Int32,
FeePpm: r.Policy1FeePpm.Int64,
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
LastUpdate: r.Policy1LastUpdate,
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
Disabled: r.Policy1Disabled,
Signature: r.Policy1Signature,
}
}
if r.Policy2ID.Valid {
policy2 = &sqlc.ChannelPolicy{
ID: r.Policy2ID.Int64,
Version: r.Policy2Version.Int16,
ChannelID: r.Channel.ID,
NodeID: r.Policy2NodeID.Int64,
Timelock: r.Policy2Timelock.Int32,
FeePpm: r.Policy2FeePpm.Int64,
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
LastUpdate: r.Policy2LastUpdate,
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
Disabled: r.Policy2Disabled,
Signature: r.Policy2Signature,
}
}
return policy1, policy2, nil
case sqlc.GetChannelBySCIDWithPoliciesRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.ChannelPolicy{