diff --git a/config_test_native_sql.go b/config_test_native_sql.go index b9b4debd0..92e8353bd 100644 --- a/config_test_native_sql.go +++ b/config_test_native_sql.go @@ -32,8 +32,8 @@ func (d *DefaultDatabaseBuilder) getGraphStore(baseDB *sqldb.BaseDB, return graphdb.NewSQLStore( &graphdb.SQLStoreConfig{ - ChainHash: *d.cfg.ActiveNetParams.GenesisHash, - PaginationCfg: sqldb.DefaultPagedQueryConfig(), + ChainHash: *d.cfg.ActiveNetParams.GenesisHash, + QueryCfg: sqldb.DefaultQueryConfig(), }, graphExecutor, opts..., ) diff --git a/graph/db/benchmark_test.go b/graph/db/benchmark_test.go index 18d45be58..1d9a782fc 100644 --- a/graph/db/benchmark_test.go +++ b/graph/db/benchmark_test.go @@ -59,7 +59,7 @@ var ( // testSQLPaginationCfg is used to configure the pagination settings for // the SQL stores we open for testing. - testSQLPaginationCfg = sqldb.DefaultPagedQueryConfig() + testSQLPaginationCfg = sqldb.DefaultQueryConfig() // testSqlitePragmaOpts is a set of SQLite pragma options that we apply // to the SQLite databases we open for testing. @@ -277,8 +277,8 @@ func newSQLExecutor(t testing.TB, db sqldb.DB) BatchedSQLQueries { func newSQLStore(t testing.TB, db BatchedSQLQueries) V1Store { store, err := NewSQLStore( &SQLStoreConfig{ - ChainHash: dbTestChain, - PaginationCfg: testSQLPaginationCfg, + ChainHash: dbTestChain, + QueryCfg: testSQLPaginationCfg, }, db, testStoreOptions..., ) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index f996bfbfa..006085aa5 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -193,8 +193,8 @@ type SQLStoreConfig struct { // messages in this store are aimed at. ChainHash chainhash.Hash - // PaginationCfg is the configuration for paginated queries. - PaginationCfg *sqldb.PagedQueryConfig + // QueryConfig holds configuration values for SQL queries. + QueryCfg *sqldb.QueryConfig } // NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries @@ -559,7 +559,7 @@ func (s *SQLStore) NodeUpdatesInHorizon(startTime, } err = forEachNodeInBatch( - ctx, s.cfg.PaginationCfg, db, dbNodes, + ctx, s.cfg.QueryCfg, db, dbNodes, func(_ int64, node *models.LightningNode) error { nodes = append(nodes, *node) @@ -847,7 +847,7 @@ func (s *SQLStore) ForEachNode(ctx context.Context, } err = forEachNodeInBatch( - ctx, s.cfg.PaginationCfg, db, nodes, nodeCB, + ctx, s.cfg.QueryCfg, db, nodes, nodeCB, ) if err != nil { return fmt.Errorf("unable to iterate over "+ @@ -1447,7 +1447,7 @@ func (s *SQLStore) ForEachChannel(ctx context.Context, } batchData, err := batchLoadChannelData( - ctx, s.cfg.PaginationCfg, db, channelIDs, + ctx, s.cfg.QueryCfg, db, channelIDs, policyIDs, ) if err != nil { @@ -2330,9 +2330,9 @@ func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context, ) } - return sqldb.ExecutePagedQuery( - ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes, - queryWrapper, cb, + return sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper, + cb, ) } @@ -2453,9 +2453,9 @@ func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, return channelIDToBytes(channelID) } - return sqldb.ExecutePagedQuery( - ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter, - queryWrapper, cb, + return sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper, + cb, ) } @@ -2612,8 +2612,8 @@ func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries, return outpoint.String() } - return sqldb.ExecutePagedQuery( - ctx, s.cfg.PaginationCfg, outpoints, outpointToString, + return sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, outpoints, outpointToString, queryWrapper, cb, ) } @@ -2631,8 +2631,8 @@ func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries, return id } - return sqldb.ExecutePagedQuery( - ctx, s.cfg.PaginationCfg, dbIDs, idConverter, + return sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, dbIDs, idConverter, queryWrapper, func(ctx context.Context, _ any) error { return nil }, @@ -3391,7 +3391,7 @@ func buildNode(ctx context.Context, db SQLQueries, // NOTE: buildNode is only used to load the data for a single node, and // so no paged queries will be performed. This means that it's ok to // used pass in default config values here. - cfg := sqldb.DefaultPagedQueryConfig() + cfg := sqldb.DefaultQueryConfig() data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID}) if err != nil { @@ -3477,7 +3477,7 @@ func buildNodeWithBatchData(dbNode *sqlc.GraphNode, // forEachNodeInBatch fetches all nodes in the provided batch, builds them // with the preloaded data, and executes the provided callback for each node. -func forEachNodeInBatch(ctx context.Context, cfg *sqldb.PagedQueryConfig, +func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries, nodes []sqlc.GraphNode, cb func(dbID int64, node *models.LightningNode) error) error { @@ -4127,7 +4127,7 @@ func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries, // NOTE: getAndBuildEdgeInfo is only used to load the data for a single // edge, and so no paged queries will be performed. This means that // it's ok to used pass in default config values here. - cfg := sqldb.DefaultPagedQueryConfig() + cfg := sqldb.DefaultQueryConfig() data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil) if err != nil { @@ -4255,7 +4255,7 @@ func getAndBuildChanPolicies(ctx context.Context, db SQLQueries, // a maximum of two policies, and so no paged queries will be // performed (unless the page size is one). So it's ok to use // the default config values here. - cfg := sqldb.DefaultPagedQueryConfig() + cfg := sqldb.DefaultQueryConfig() batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs) if err != nil { @@ -4779,7 +4779,7 @@ type nodeAddress struct { // batchLoadNodeData loads all related data for a batch of node IDs using the // provided SQLQueries interface. It returns a batchNodeData instance containing // the node features, addresses and extra signed fields. -func batchLoadNodeData(ctx context.Context, cfg *sqldb.PagedQueryConfig, +func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries, nodeIDs []int64) (*batchNodeData, error) { // Batch load the node features. @@ -4811,14 +4811,14 @@ func batchLoadNodeData(ctx context.Context, cfg *sqldb.PagedQueryConfig, } // batchLoadNodeFeaturesHelper loads node features for a batch of node IDs -// using ExecutePagedQuery wrapper around the GetNodeFeaturesBatch query. +// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query. func batchLoadNodeFeaturesHelper(ctx context.Context, - cfg *sqldb.PagedQueryConfig, db SQLQueries, + cfg *sqldb.QueryConfig, db SQLQueries, nodeIDs []int64) (map[int64][]int, error) { features := make(map[int64][]int) - return features, sqldb.ExecutePagedQuery( + return features, sqldb.ExecuteBatchQuery( ctx, cfg, nodeIDs, func(id int64) int64 { return id @@ -4839,16 +4839,16 @@ func batchLoadNodeFeaturesHelper(ctx context.Context, ) } -// batchLoadNodeAddressesHelper loads node addresses using ExecutePagedQuery +// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery // wrapper around the GetNodeAddressesBatch query. It returns a map from // node ID to a slice of nodeAddress structs. func batchLoadNodeAddressesHelper(ctx context.Context, - cfg *sqldb.PagedQueryConfig, db SQLQueries, + cfg *sqldb.QueryConfig, db SQLQueries, nodeIDs []int64) (map[int64][]nodeAddress, error) { addrs := make(map[int64][]nodeAddress) - return addrs, sqldb.ExecutePagedQuery( + return addrs, sqldb.ExecuteBatchQuery( ctx, cfg, nodeIDs, func(id int64) int64 { return id @@ -4873,10 +4873,10 @@ func batchLoadNodeAddressesHelper(ctx context.Context, } // batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of -// node IDs using ExecutePagedQuery wrapper around the GetNodeExtraTypesBatch +// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch // query. func batchLoadNodeExtraTypesHelper(ctx context.Context, - cfg *sqldb.PagedQueryConfig, db SQLQueries, + cfg *sqldb.QueryConfig, db SQLQueries, nodeIDs []int64) (map[int64]map[uint64][]byte, error) { extraFields := make(map[int64]map[uint64][]byte) @@ -4892,7 +4892,7 @@ func batchLoadNodeExtraTypesHelper(ctx context.Context, return nil } - return extraFields, sqldb.ExecutePagedQuery( + return extraFields, sqldb.ExecuteBatchQuery( ctx, cfg, nodeIDs, func(id int64) int64 { return id @@ -4967,7 +4967,7 @@ type batchChannelData struct { // batchLoadChannelData loads all related data for batches of channels and // policies. -func batchLoadChannelData(ctx context.Context, cfg *sqldb.PagedQueryConfig, +func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries, channelIDs []int64, policyIDs []int64) (*batchChannelData, error) { @@ -5012,16 +5012,16 @@ func batchLoadChannelData(ctx context.Context, cfg *sqldb.PagedQueryConfig, } // batchLoadChannelFeaturesHelper loads channel features for a batch of -// channel IDs using ExecutePagedQuery wrapper around the +// channel IDs using ExecuteBatchQuery wrapper around the // GetChannelFeaturesBatch query. It returns a map from DB channel ID to a // slice of feature bits. func batchLoadChannelFeaturesHelper(ctx context.Context, - cfg *sqldb.PagedQueryConfig, db SQLQueries, + cfg *sqldb.QueryConfig, db SQLQueries, channelIDs []int64) (map[int64][]int, error) { features := make(map[int64][]int) - return features, sqldb.ExecutePagedQuery( + return features, sqldb.ExecuteBatchQuery( ctx, cfg, channelIDs, func(id int64) int64 { return id @@ -5045,11 +5045,11 @@ func batchLoadChannelFeaturesHelper(ctx context.Context, } // batchLoadChannelExtrasHelper loads channel extra types for a batch of -// channel IDs using ExecutePagedQuery wrapper around the GetChannelExtrasBatch +// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch // query. It returns a map from DB channel ID to a map of TLV type to extra // signed field bytes. func batchLoadChannelExtrasHelper(ctx context.Context, - cfg *sqldb.PagedQueryConfig, db SQLQueries, + cfg *sqldb.QueryConfig, db SQLQueries, channelIDs []int64) (map[int64]map[uint64][]byte, error) { extras := make(map[int64]map[uint64][]byte) @@ -5065,7 +5065,7 @@ func batchLoadChannelExtrasHelper(ctx context.Context, return nil } - return extras, sqldb.ExecutePagedQuery( + return extras, sqldb.ExecuteBatchQuery( ctx, cfg, channelIDs, func(id int64) int64 { return id @@ -5079,16 +5079,16 @@ func batchLoadChannelExtrasHelper(ctx context.Context, } // batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a -// batch of policy IDs using ExecutePagedQuery wrapper around the +// batch of policy IDs using ExecuteBatchQuery wrapper around the // GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to // a map of TLV type to extra signed field bytes. func batchLoadChannelPolicyExtrasHelper(ctx context.Context, - cfg *sqldb.PagedQueryConfig, db SQLQueries, + cfg *sqldb.QueryConfig, db SQLQueries, policyIDs []int64) (map[int64]map[uint64][]byte, error) { extras := make(map[int64]map[uint64][]byte) - return extras, sqldb.ExecutePagedQuery( + return extras, sqldb.ExecuteBatchQuery( ctx, cfg, policyIDs, func(id int64) int64 { return id diff --git a/graph/db/test_postgres.go b/graph/db/test_postgres.go index a812f998f..463a89db2 100644 --- a/graph/db/test_postgres.go +++ b/graph/db/test_postgres.go @@ -42,8 +42,8 @@ func NewTestDBWithFixture(t testing.TB, store, err := NewSQLStore( &SQLStoreConfig{ - ChainHash: *chaincfg.MainNetParams.GenesisHash, - PaginationCfg: sqldb.DefaultPagedQueryConfig(), + ChainHash: *chaincfg.MainNetParams.GenesisHash, + QueryCfg: sqldb.DefaultQueryConfig(), }, querier, ) require.NoError(t, err) diff --git a/graph/db/test_sqlite.go b/graph/db/test_sqlite.go index 41773c66e..3442bfd6e 100644 --- a/graph/db/test_sqlite.go +++ b/graph/db/test_sqlite.go @@ -27,8 +27,8 @@ func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture { func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture) V1Store { store, err := NewSQLStore( &SQLStoreConfig{ - ChainHash: *chaincfg.MainNetParams.GenesisHash, - PaginationCfg: sqldb.DefaultPagedQueryConfig(), + ChainHash: *chaincfg.MainNetParams.GenesisHash, + QueryCfg: sqldb.DefaultQueryConfig(), }, newBatchQuerier(t), ) require.NoError(t, err) diff --git a/itest/lnd_graph_migration_test.go b/itest/lnd_graph_migration_test.go index 04745da49..fe3e6b175 100644 --- a/itest/lnd_graph_migration_test.go +++ b/itest/lnd_graph_migration_test.go @@ -144,8 +144,8 @@ func openNativeSQLGraphDB(ht *lntest.HarnessTest, store, err := graphdb.NewSQLStore( &graphdb.SQLStoreConfig{ - ChainHash: *ht.Miner().ActiveNet.GenesisHash, - PaginationCfg: sqldb.DefaultPagedQueryConfig(), + ChainHash: *ht.Miner().ActiveNet.GenesisHash, + QueryCfg: sqldb.DefaultQueryConfig(), }, executor, ) diff --git a/sqldb/paginate.go b/sqldb/paginate.go index e0b382b40..281c4651b 100644 --- a/sqldb/paginate.go +++ b/sqldb/paginate.go @@ -5,46 +5,56 @@ import ( "fmt" ) -// PagedQueryFunc represents a function that takes a slice of converted items +// QueryConfig holds configuration values for SQL queries. +type QueryConfig struct { + // MaxBatchSize is the maximum number of items included in a batch + // query IN clauses list. + MaxBatchSize int +} + +// DefaultQueryConfig returns a default configuration for SQL queries. +func DefaultQueryConfig() *QueryConfig { + return &QueryConfig{ + // TODO(elle): make configurable & have different defaults + // for SQLite and Postgres. + MaxBatchSize: 250, + } +} + +// BatchQueryFunc represents a function that takes a batch of converted items // and returns results. -type PagedQueryFunc[T any, R any] func(context.Context, []T) ([]R, error) +type BatchQueryFunc[T any, R any] func(context.Context, []T) ([]R, error) // ItemCallbackFunc represents a function that processes individual results. type ItemCallbackFunc[R any] func(context.Context, R) error // ConvertFunc represents a function that converts from input type to query type +// for the batch query. type ConvertFunc[I any, T any] func(I) T -// PagedQueryConfig holds configuration values for calls to ExecutePagedQuery. -type PagedQueryConfig struct { - PageSize int -} - -// DefaultPagedQueryConfig returns a default configuration -func DefaultPagedQueryConfig() *PagedQueryConfig { - return &PagedQueryConfig{ - // TODO(elle): make configurable & have different defaults - // for SQLite and Postgres. - PageSize: 250, - } -} - -// ExecutePagedQuery executes a paginated query over a slice of input items. +// ExecuteBatchQuery executes a query in batches over a slice of input items. // It converts the input items to a query type using the provided convertFunc, -// executes the query using the provided queryFunc, and applies the callback -// to each result. -func ExecutePagedQuery[I any, T any, R any](ctx context.Context, - cfg *PagedQueryConfig, inputItems []I, convertFunc ConvertFunc[I, T], - queryFunc PagedQueryFunc[T, R], callback ItemCallbackFunc[R]) error { +// executes the query in batches using the provided queryFunc, and applies +// the callback to each result. This is useful for queries using the +// "WHERE x IN []slice" pattern. It takes that slice, splits it into batches of +// size MaxBatchSize, and executes the query for each batch. +// +// NOTE: it is the caller's responsibility to ensure that the expected return +// results are unique across all pages. Meaning that if the input items are +// split up, a result that is returned in one page should not be expected to +// be returned in another page. +func ExecuteBatchQuery[I any, T any, R any](ctx context.Context, + cfg *QueryConfig, inputItems []I, convertFunc ConvertFunc[I, T], + queryFunc BatchQueryFunc[T, R], callback ItemCallbackFunc[R]) error { if len(inputItems) == 0 { return nil } // Process items in pages. - for i := 0; i < len(inputItems); i += cfg.PageSize { + for i := 0; i < len(inputItems); i += cfg.MaxBatchSize { // Calculate the end index for this page. - end := i + cfg.PageSize + end := i + cfg.MaxBatchSize if end > len(inputItems) { end = len(inputItems) } diff --git a/sqldb/paginate_test.go b/sqldb/paginate_test.go index 0ebf25371..879f292e1 100644 --- a/sqldb/paginate_test.go +++ b/sqldb/paginate_test.go @@ -12,17 +12,17 @@ import ( "github.com/stretchr/testify/require" ) -// TestExecutePagedQuery tests the ExecutePagedQuery function which processes +// TestExecuteBatchQuery tests the ExecuteBatchQuery function which processes // items in pages, allowing for efficient querying and processing of large // datasets. -func TestExecutePagedQuery(t *testing.T) { +func TestExecuteBatchQuery(t *testing.T) { t.Parallel() ctx := context.Background() t.Run("empty input returns nil", func(t *testing.T) { var ( - cfg = DefaultPagedQueryConfig() + cfg = DefaultQueryConfig() inputItems []int ) @@ -44,7 +44,7 @@ func TestExecutePagedQuery(t *testing.T) { return nil } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, cfg, inputItems, convertFunc, queryFunc, callback, ) require.NoError(t, err) @@ -55,8 +55,8 @@ func TestExecutePagedQuery(t *testing.T) { convertedItems []string callbackResults []string inputItems = []int{1, 2, 3, 4, 5} - cfg = &PagedQueryConfig{ - PageSize: 10, + cfg = &QueryConfig{ + MaxBatchSize: 10, } ) @@ -81,7 +81,7 @@ func TestExecutePagedQuery(t *testing.T) { return nil } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, cfg, inputItems, convertFunc, queryFunc, callback, ) require.NoError(t, err) @@ -104,8 +104,8 @@ func TestExecutePagedQuery(t *testing.T) { pageSizes []int allResults []string inputItems = []int{1, 2, 3, 4, 5, 6, 7, 8} - cfg = &PagedQueryConfig{ - PageSize: 3, + cfg = &QueryConfig{ + MaxBatchSize: 3, } ) @@ -131,7 +131,7 @@ func TestExecutePagedQuery(t *testing.T) { return nil } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, cfg, inputItems, convertFunc, queryFunc, callback, ) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestExecutePagedQuery(t *testing.T) { t.Run("query function error is propagated", func(t *testing.T) { var ( - cfg = DefaultPagedQueryConfig() + cfg = DefaultQueryConfig() inputItems = []int{1, 2, 3} ) @@ -165,7 +165,7 @@ func TestExecutePagedQuery(t *testing.T) { return nil } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, cfg, inputItems, convertFunc, queryFunc, callback, ) require.ErrorContains(t, err, "query failed for page "+ @@ -174,7 +174,7 @@ func TestExecutePagedQuery(t *testing.T) { t.Run("callback error is propagated", func(t *testing.T) { var ( - cfg = DefaultPagedQueryConfig() + cfg = DefaultQueryConfig() inputItems = []int{1, 2, 3} ) @@ -195,7 +195,7 @@ func TestExecutePagedQuery(t *testing.T) { return nil } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, cfg, inputItems, convertFunc, queryFunc, callback, ) require.ErrorContains(t, err, "callback failed for result: "+ @@ -205,8 +205,8 @@ func TestExecutePagedQuery(t *testing.T) { t.Run("query error in second page is propagated", func(t *testing.T) { var ( inputItems = []int{1, 2, 3, 4} - cfg = &PagedQueryConfig{ - PageSize: 2, + cfg = &QueryConfig{ + MaxBatchSize: 2, } queryCallCount int ) @@ -230,7 +230,7 @@ func TestExecutePagedQuery(t *testing.T) { return nil } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, cfg, inputItems, convertFunc, queryFunc, callback, ) require.ErrorContains(t, err, "query failed for page "+ @@ -238,10 +238,10 @@ func TestExecutePagedQuery(t *testing.T) { }) } -// TestSQLSliceQueries tests ExecutePageQuery helper by first showing that a +// TestSQLSliceQueries tests ExecuteBatchQuery helper by first showing that a // query the /*SLICE:*/ directive has a maximum number of // parameters it can handle, and then showing that the paginated version which -// uses ExecutePagedQuery instead of a raw query can handle more parameters by +// uses ExecuteBatchQuery instead of a raw query can handle more parameters by // executing the query in pages. func TestSQLSliceQueries(t *testing.T) { t.Parallel() @@ -299,9 +299,9 @@ func TestSQLSliceQueries(t *testing.T) { return db.GetChannelsByOutpoints(ctx, pageOutpoints) } - err := ExecutePagedQuery( + err := ExecuteBatchQuery( ctx, - DefaultPagedQueryConfig(), + DefaultQueryConfig(), queryParams, func(s string) string { return s