diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index cc8b199bc..5e9305489 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -93,6 +93,7 @@ type SQLQueries interface { CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error) + GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error) GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) @@ -2259,31 +2260,49 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, ctx = context.TODO() newChanIDs []uint64 knownZombies []ChannelUpdateInfo + infoLookup = make( + map[uint64]ChannelUpdateInfo, len(chansInfo), + ) ) + + // We first build a lookup map of the channel ID's to the + // ChannelUpdateInfo. This allows us to quickly delete channels that we + // already know about. + for _, chanInfo := range chansInfo { + infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo + } + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // The call-back function deletes known channels from + // infoLookup, so that we can later check which channels are + // zombies by only looking at the remaining channels in the set. + cb := func(ctx context.Context, + channel sqlc.GraphChannel) error { + + delete(infoLookup, byteOrder.Uint64(channel.Scid)) + + return nil + } + + err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo) + if err != nil { + return fmt.Errorf("unable to iterate through "+ + "channels: %w", err) + } + + // We want to ensure that we deal with the channels in the + // same order that they were passed in, so we iterate over the + // original chansInfo slice and then check if that channel is + // still in the infoLookup map. for _, chanInfo := range chansInfo { channelID := chanInfo.ShortChannelID.ToUint64() - chanIDB := channelIDToBytes(channelID) - - // TODO(elle): potentially optimize this by using - // sqlc.slice() once that works for both SQLite and - // Postgres. - _, err := db.GetChannelBySCID( - ctx, sqlc.GetChannelBySCIDParams{ - Version: int16(ProtocolV1), - Scid: chanIDB, - }, - ) - if err == nil { + if _, ok := infoLookup[channelID]; !ok { continue - } else if !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("unable to fetch channel: %w", - err) } isZombie, err := db.IsZombieChannel( ctx, sqlc.IsZombieChannelParams{ - Scid: chanIDB, + Scid: channelIDToBytes(channelID), Version: int16(ProtocolV1), }, ) @@ -2305,6 +2324,11 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, }, func() { newChanIDs = nil knownZombies = nil + // Rebuild the infoLookup map in case of a rollback. + for _, chanInfo := range chansInfo { + scid := chanInfo.ShortChannelID.ToUint64() + infoLookup[scid] = chanInfo + } }) if err != nil { return nil, nil, fmt.Errorf("unable to fetch channels: %w", err) @@ -2313,6 +2337,37 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, return newChanIDs, knownZombies, nil } +// forEachChanInSCIDList is a helper method that executes a paged query +// against the database to fetch all channels that match the passed +// ChannelUpdateInfo slice. The callback function is called for each channel +// that is found. +func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, + cb func(ctx context.Context, channel sqlc.GraphChannel) error, + chansInfo []ChannelUpdateInfo) error { + + queryWrapper := func(ctx context.Context, + scids [][]byte) ([]sqlc.GraphChannel, error) { + + return db.GetChannelsBySCIDs( + ctx, sqlc.GetChannelsBySCIDsParams{ + Version: int16(ProtocolV1), + Scids: scids, + }, + ) + } + + chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte { + channelID := chanInfo.ShortChannelID.ToUint64() + + return channelIDToBytes(channelID) + } + + return sqldb.ExecutePagedQuery( + ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter, + queryWrapper, cb, + ) +} + // PruneGraphNodes is a garbage collection method which attempts to prune out // any nodes from the channel graph that are currently unconnected. This ensure // that we only maintain a graph of reachable nodes. In the event that a pruned diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 755aea5f3..09df90264 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1154,6 +1154,65 @@ func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsByS return items, nil } +const getChannelsBySCIDs = `-- name: GetChannelsBySCIDs :many +SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM graph_channels +WHERE version = $1 + AND scid IN (/*SLICE:scids*/?) +` + +type GetChannelsBySCIDsParams struct { + Version int16 + Scids [][]byte +} + +func (q *Queries) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) { + query := getChannelsBySCIDs + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Scids) > 0 { + for _, v := range arg.Scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphChannel + for rows.Next() { + var i GraphChannel + if err := rows.Scan( + &i.ID, + &i.Version, + &i.Scid, + &i.NodeID1, + &i.NodeID2, + &i.Outpoint, + &i.Capacity, + &i.BitcoinKey1, + &i.BitcoinKey2, + &i.Node1Signature, + &i.Node2Signature, + &i.Bitcoin1Signature, + &i.Bitcoin2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM graph_node_extra_types diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 786168bb5..01d45734a 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -44,6 +44,7 @@ type Querier interface { GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) + GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error) // This method may return more than one invoice if filter using multiple fields diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index f86197749..5216ddea5 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -231,6 +231,11 @@ WHERE scid >= @start_scid SELECT * FROM graph_channels WHERE scid = $1 AND version = $2; +-- name: GetChannelsBySCIDs :many +SELECT * FROM graph_channels +WHERE version = @version + AND scid IN (sqlc.slice('scids')/*SLICE:scids*/); + -- name: GetChannelsByOutpoints :many SELECT sqlc.embed(c),