From de6c030f29b442f8ea6745c4707bb764b995ed15 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 18 Jul 2025 11:26:22 +0200 Subject: [PATCH] graph/db: let DeleteChannelEdges use new wrapped SQL call Update it to use the new wrapped version of GetChannelsBySCIDWithPolicies to reduce the number of DB calls. --- graph/db/sql_store.go | 49 +++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index a782f4b45..1c02a7b89 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1713,26 +1713,25 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, s.cacheMu.Lock() defer s.cacheMu.Unlock() + // Keep track of which channels we end up finding so that we can + // correctly return ErrEdgeNotFound if we do not find a channel. + chanLookup := make(map[uint64]struct{}, len(chanIDs)) + for _, chanID := range chanIDs { + chanLookup[chanID] = struct{}{} + } + var ( ctx = context.TODO() deleted []*models.ChannelEdgeInfo ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - for _, chanID := range chanIDs { - chanIDB := channelIDToBytes(chanID) + chanCallBack := func(ctx context.Context, + row sqlc.GetChannelsBySCIDWithPoliciesRow) error { - row, err := db.GetChannelBySCIDWithPolicies( - ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB, - Version: int16(ProtocolV1), - }, - ) - if errors.Is(err, sql.ErrNoRows) { - return ErrEdgeNotFound - } else if err != nil { - return fmt.Errorf("unable to fetch channel: %w", - err) - } + // Deleting the entry from the map indicates that we + // have found the channel. + scid := byteOrder.Uint64(row.GraphChannel.Scid) + delete(chanLookup, scid) node1, node2, err := buildNodeVertices( row.GraphNode.PubKey, row.GraphNode_2.PubKey, @@ -1758,7 +1757,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, deleted = append(deleted, info) if !markZombie { - continue + return nil } nodeKey1, nodeKey2 := info.NodeKey1Bytes, @@ -1786,7 +1785,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, err = db.UpsertZombieChannel( ctx, sqlc.UpsertZombieChannelParams{ Version: int16(ProtocolV1), - Scid: chanIDB, + Scid: channelIDToBytes(scid), NodeKey1: nodeKey1[:], NodeKey2: nodeKey2[:], }, @@ -1795,11 +1794,29 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, return fmt.Errorf("unable to mark channel as "+ "zombie: %w", err) } + + return nil + } + + err := s.forEachChanWithPoliciesInSCIDList( + ctx, db, chanCallBack, chanIDs, + ) + if err != nil { + return err + } + + if len(chanLookup) > 0 { + return ErrEdgeNotFound } return nil }, func() { deleted = nil + + // Re-fill the lookup map. + for _, chanID := range chanIDs { + chanLookup[chanID] = struct{}{} + } }) if err != nil { return nil, fmt.Errorf("unable to delete channel edges: %w",