diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index c828cbc53..7dfd0b7f7 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1275,26 +1275,11 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo, return err } - var pol1, pol2 *models.CachedEdgePolicy - if dbPol1 != nil { - policy1, err := buildChanPolicy( - *dbPol1, edge.ChannelID, nil, node2, - ) - if err != nil { - return err - } - - pol1 = models.NewCachedPolicy(policy1) - } - if dbPol2 != nil { - policy2, err := buildChanPolicy( - *dbPol2, edge.ChannelID, nil, node1, - ) - if err != nil { - return err - } - - pol2 = models.NewCachedPolicy(policy2) + pol1, pol2, err := buildCachedChanPolicies( + dbPol1, dbPol2, edge.ChannelID, node1, node2, + ) + if err != nil { + return err } if err := cb(edge, pol1, pol2); err != nil { @@ -2646,58 +2631,50 @@ func (s *SQLStore) ChannelView() ([]EdgePoint, error) { edgePoints []EdgePoint ) - handleChannel := func(db SQLQueries, - channel sqlc.ListChannelsPaginatedRow) error { - - pkScript, err := genMultiSigP2WSH( - channel.BitcoinKey1, channel.BitcoinKey2, - ) - if err != nil { - return err - } - - op, err := wire.NewOutPointFromString(channel.Outpoint) - if err != nil { - return err - } - - edgePoints = append(edgePoints, EdgePoint{ - FundingPkScript: pkScript, - OutPoint: *op, - }) - - return nil - } - err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - lastID := int64(-1) - for { - rows, err := db.ListChannelsPaginated( - ctx, sqlc.ListChannelsPaginatedParams{ - Version: int16(ProtocolV1), - ID: lastID, - Limit: s.cfg.QueryCfg.MaxPageSize, - }, + handleChannel := func(_ context.Context, + channel sqlc.ListChannelsPaginatedRow) error { + + pkScript, err := genMultiSigP2WSH( + channel.BitcoinKey1, channel.BitcoinKey2, ) if err != nil { return err } - if len(rows) == 0 { - break + op, err := wire.NewOutPointFromString(channel.Outpoint) + if err != nil { + return err } - for _, row := range rows { - err := handleChannel(db, row) - if err != nil { - return err - } + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: *op, + }) - lastID = row.ID - } + return nil } - return nil + queryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.ListChannelsPaginatedRow, error) { + + return db.ListChannelsPaginated( + ctx, sqlc.ListChannelsPaginatedParams{ + Version: int16(ProtocolV1), + ID: lastID, + Limit: limit, + }, + ) + } + + extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 { + return row.ID + } + + return sqldb.ExecutePaginatedQuery( + ctx, s.cfg.QueryCfg, int64(-1), queryFunc, + extractCursor, handleChannel, + ) }, func() { edgePoints = nil }) @@ -3071,26 +3048,11 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries, return err } - var p1, p2 *models.CachedEdgePolicy - if dbPol1 != nil { - policy1, err := buildChanPolicy( - *dbPol1, edge.ChannelID, nil, node2, - ) - if err != nil { - return err - } - - p1 = models.NewCachedPolicy(policy1) - } - if dbPol2 != nil { - policy2, err := buildChanPolicy( - *dbPol2, edge.ChannelID, nil, node1, - ) - if err != nil { - return err - } - - p2 = models.NewCachedPolicy(policy2) + p1, p2, err := buildCachedChanPolicies( + dbPol1, dbPol2, edge.ChannelID, node1, node2, + ) + if err != nil { + return err } // Determine the outgoing and incoming policy for this @@ -4276,6 +4238,34 @@ func getAndBuildChanPolicies(ctx context.Context, db SQLQueries, return pol1, pol2, nil } +// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the +// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil, +// then nil is returned for it. +func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy, + channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy, + *models.CachedEdgePolicy, error) { + + var p1, p2 *models.CachedEdgePolicy + if dbPol1 != nil { + policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2) + if err != nil { + return nil, nil, err + } + + p1 = models.NewCachedPolicy(policy1) + } + if dbPol2 != nil { + policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1) + if err != nil { + return nil, nil, err + } + + p2 = models.NewCachedPolicy(policy2) + } + + return p1, p2, nil +} + // buildChanPolicy builds a models.ChannelEdgePolicy instance from the // provided sqlc.GraphChannelPolicy and other required information. func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,