graph/db: let DeleteChannelEdges use new wrapped SQL call

Update it to use the new wrapped version of
GetChannelsBySCIDWithPolicies to reduce the number of DB calls.
This commit is contained in:
Elle Mouton
2025-07-18 11:26:22 +02:00
parent e269d57ffa
commit de6c030f29

View File

@@ -1713,26 +1713,25 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
s.cacheMu.Lock() s.cacheMu.Lock()
defer s.cacheMu.Unlock() defer s.cacheMu.Unlock()
// Keep track of which channels we end up finding so that we can
// correctly return ErrEdgeNotFound if we do not find a channel.
chanLookup := make(map[uint64]struct{}, len(chanIDs))
for _, chanID := range chanIDs {
chanLookup[chanID] = struct{}{}
}
var ( var (
ctx = context.TODO() ctx = context.TODO()
deleted []*models.ChannelEdgeInfo deleted []*models.ChannelEdgeInfo
) )
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
for _, chanID := range chanIDs { chanCallBack := func(ctx context.Context,
chanIDB := channelIDToBytes(chanID) row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
row, err := db.GetChannelBySCIDWithPolicies( // Deleting the entry from the map indicates that we
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ // have found the channel.
Scid: chanIDB, scid := byteOrder.Uint64(row.GraphChannel.Scid)
Version: int16(ProtocolV1), delete(chanLookup, scid)
},
)
if errors.Is(err, sql.ErrNoRows) {
return ErrEdgeNotFound
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w",
err)
}
node1, node2, err := buildNodeVertices( node1, node2, err := buildNodeVertices(
row.GraphNode.PubKey, row.GraphNode_2.PubKey, row.GraphNode.PubKey, row.GraphNode_2.PubKey,
@@ -1758,7 +1757,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
deleted = append(deleted, info) deleted = append(deleted, info)
if !markZombie { if !markZombie {
continue return nil
} }
nodeKey1, nodeKey2 := info.NodeKey1Bytes, nodeKey1, nodeKey2 := info.NodeKey1Bytes,
@@ -1786,7 +1785,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
err = db.UpsertZombieChannel( err = db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{ ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
Scid: chanIDB, Scid: channelIDToBytes(scid),
NodeKey1: nodeKey1[:], NodeKey1: nodeKey1[:],
NodeKey2: nodeKey2[:], NodeKey2: nodeKey2[:],
}, },
@@ -1795,11 +1794,29 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
return fmt.Errorf("unable to mark channel as "+ return fmt.Errorf("unable to mark channel as "+
"zombie: %w", err) "zombie: %w", err)
} }
return nil
}
err := s.forEachChanWithPoliciesInSCIDList(
ctx, db, chanCallBack, chanIDs,
)
if err != nil {
return err
}
if len(chanLookup) > 0 {
return ErrEdgeNotFound
} }
return nil return nil
}, func() { }, func() {
deleted = nil deleted = nil
// Re-fill the lookup map.
for _, chanID := range chanIDs {
chanLookup[chanID] = struct{}{}
}
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to delete channel edges: %w", return nil, fmt.Errorf("unable to delete channel edges: %w",