sqldb: add ExecuteCollectAndBatchWithSharedDataQuery helper

In this commit we add a new helper method in the sqldb package:
ExecuteCollectAndBatchWithSharedDataQuery. This can be used to paginate
through items in the database while at the same time performing batch
data collection for those items.
This commit is contained in:
Elle Mouton
2025-07-31 12:39:13 +02:00
parent 1a60edbd33
commit 905941067e
2 changed files with 626 additions and 0 deletions

View File

@@ -157,3 +157,99 @@ func ExecutePaginatedQuery[C any, T any](ctx context.Context, cfg *QueryConfig,
return nil
}
// CollectAndBatchDataQueryFunc represents a function that batch loads
// additional data for collected identifiers, returning the batch data that
// applies to all items.
type CollectAndBatchDataQueryFunc[ID any, BatchData any] func(context.Context,
[]ID) (BatchData, error)
// ItemWithBatchDataProcessFunc represents a function that processes individual
// items along with shared batch data.
type ItemWithBatchDataProcessFunc[T any, BatchData any] func(context.Context,
T, BatchData) error
// CollectFunc represents a function that extracts an identifier from a
// paginated item.
type CollectFunc[T any, ID any] func(T) (ID, error)
// ExecuteCollectAndBatchWithSharedDataQuery implements a page-by-page
// processing pattern where each page is immediately processed with batch-loaded
// data before moving to the next page.
//
// It:
// 1. Fetches a page of items using cursor-based pagination
// 2. Collects identifiers from that page and batch loads shared data
// 3. Processes each item in the page with the shared batch data
// 4. Moves to the next page and repeats
//
// Parameters:
// - initialCursor: starting cursor for pagination
// - pageQueryFunc: fetches a page of items
// - extractPageCursor: extracts cursor from paginated item for next page
// - collectFunc: extracts identifier from paginated item
// - batchDataFunc: batch loads shared data from collected IDs for one page
// - processItem: processes each item with the shared batch data
func ExecuteCollectAndBatchWithSharedDataQuery[C any, T any, I any, D any](
ctx context.Context, cfg *QueryConfig, initialCursor C,
pageQueryFunc PagedQueryFunc[C, T],
extractPageCursor CursorExtractFunc[T, C],
collectFunc CollectFunc[T, I],
batchDataFunc CollectAndBatchDataQueryFunc[I, D],
processItem ItemWithBatchDataProcessFunc[T, D]) error {
cursor := initialCursor
for {
// Step 1: Fetch the next page of items.
items, err := pageQueryFunc(ctx, cursor, cfg.MaxPageSize)
if err != nil {
return fmt.Errorf("failed to fetch page with "+
"cursor %v: %w", cursor, err)
}
// If no items returned, we're done.
if len(items) == 0 {
break
}
// Step 2: Collect identifiers from this page and batch load
// data.
pageIDs := make([]I, len(items))
for i, item := range items {
pageIDs[i], err = collectFunc(item)
if err != nil {
return fmt.Errorf("failed to collect "+
"identifier from item: %w", err)
}
}
// Batch load shared data for this page.
batchData, err := batchDataFunc(ctx, pageIDs)
if err != nil {
return fmt.Errorf("failed to load batch data for "+
"page: %w", err)
}
// Step 3: Process each item in this page with the shared batch
// data.
for _, item := range items {
err := processItem(ctx, item, batchData)
if err != nil {
return fmt.Errorf("failed to process item "+
"with batch data: %w", err)
}
// Update cursor for next page.
cursor = extractPageCursor(item)
}
// If the number of items is less than the max page size,
// we assume there are no more items to fetch.
if len(items) < int(cfg.MaxPageSize) {
break
}
}
return nil
}

View File

@@ -558,3 +558,533 @@ func TestExecutePaginatedQuery(t *testing.T) {
})
}
}
// TestExecuteCollectAndBatchWithSharedDataQuery tests the
// ExecuteCollectAndBatchWithSharedDataQuery function which processes items in
// pages, allowing for efficient querying and processing of large datasets with
// shared data across batches.
func TestExecuteCollectAndBatchWithSharedDataQuery(t *testing.T) {
t.Parallel()
type channelRow struct {
id int64
name string
policyIDs []int64
}
type channelBatchData struct {
lookupTable map[int64]string
sharedInfo string
}
type processedChannel struct {
id int64
name string
sharedInfo string
lookupData string
}
tests := []struct {
name string
maxPageSize int32
allRows []channelRow
initialCursor int64
pageQueryError error
pageQueryErrorOnCall int
batchDataError error
batchDataErrorOnBatch int
processError error
processErrorOnID int64
earlyTerminationOnPage int
expectedError string
expectedProcessedItems []processedChannel
expectedPageCalls int
expectedBatchCalls int
}{
{
name: "multiple pages multiple batches",
maxPageSize: 2,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
policyIDs: []int64{10, 11},
},
{
id: 2,
name: "Chan2",
policyIDs: []int64{20},
},
{
id: 3,
name: "Chan3",
policyIDs: []int64{30, 31},
},
{
id: 4,
name: "Chan4",
policyIDs: []int64{40},
},
{
id: 5,
name: "Chan5",
policyIDs: []int64{50},
},
},
initialCursor: 0,
expectedProcessedItems: []processedChannel{
{
id: 1,
name: "Chan1",
sharedInfo: "batch-shared",
lookupData: "lookup-1",
},
{
id: 2,
name: "Chan2",
sharedInfo: "batch-shared",
lookupData: "lookup-2",
},
{
id: 3,
name: "Chan3",
sharedInfo: "batch-shared",
lookupData: "lookup-3",
},
{
id: 4,
name: "Chan4",
sharedInfo: "batch-shared",
lookupData: "lookup-4",
},
{
id: 5,
name: "Chan5",
sharedInfo: "batch-shared",
lookupData: "lookup-5",
},
},
// Pages: [1,2], [3,4], [5].
expectedPageCalls: 3,
// One batch call per page with data: [1,2], [3,4], [5].
expectedBatchCalls: 3,
},
{
name: "empty results",
maxPageSize: 10,
allRows: []channelRow{},
initialCursor: 0,
// One call that returns empty.
expectedPageCalls: 1,
// No batches since no items.
expectedBatchCalls: 0,
},
{
name: "single page single batch",
maxPageSize: 10,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
policyIDs: []int64{10},
},
{
id: 2,
name: "Chan2",
policyIDs: []int64{20},
},
},
initialCursor: 0,
expectedProcessedItems: []processedChannel{
{
id: 1,
name: "Chan1",
sharedInfo: "batch-shared",
lookupData: "lookup-1",
},
{
id: 2,
name: "Chan2",
sharedInfo: "batch-shared",
lookupData: "lookup-2",
},
},
// One page with all items.
expectedPageCalls: 1,
// One batch call for the single page.
expectedBatchCalls: 1,
},
{
name: "page query error first call",
maxPageSize: 5,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
},
initialCursor: 0,
pageQueryError: errors.New(
"database connection failed",
),
pageQueryErrorOnCall: 1,
expectedError: "failed to fetch page with " +
"cursor 0",
expectedPageCalls: 1,
expectedBatchCalls: 0,
},
{
name: "page query error second call",
maxPageSize: 1,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
},
initialCursor: 0,
pageQueryError: errors.New("database error on " +
"second page"),
pageQueryErrorOnCall: 2,
expectedError: "failed to fetch page with " +
"cursor 1",
expectedProcessedItems: []processedChannel{
{
id: 1,
name: "Chan1",
sharedInfo: "batch-shared",
lookupData: "lookup-1",
},
},
expectedPageCalls: 2,
expectedBatchCalls: 1,
},
{
name: "batch data error first batch",
maxPageSize: 10,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
},
initialCursor: 0,
batchDataError: errors.New("batch loading " +
"failed"),
batchDataErrorOnBatch: 1,
expectedError: "failed to load batch data " +
"for page",
expectedPageCalls: 1,
expectedBatchCalls: 1,
},
{
name: "batch data error second page",
maxPageSize: 1,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
},
initialCursor: 0,
batchDataError: errors.New("batch loading " +
"failed on second page"),
batchDataErrorOnBatch: 2,
expectedError: "failed to load batch data " +
"for page",
expectedProcessedItems: []processedChannel{
{
id: 1,
name: "Chan1",
sharedInfo: "batch-shared",
lookupData: "lookup-1",
},
},
expectedPageCalls: 2,
expectedBatchCalls: 2,
},
{
name: "process error first item",
maxPageSize: 10,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
},
initialCursor: 0,
processError: errors.New("processing failed"),
processErrorOnID: 1,
expectedError: "failed to process item with " +
"batch data",
expectedPageCalls: 1,
expectedBatchCalls: 1,
},
{
name: "process error second item",
maxPageSize: 10,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
},
initialCursor: 0,
processError: errors.New("processing failed"),
processErrorOnID: 2,
expectedError: "failed to process item with batch " +
"data",
expectedProcessedItems: []processedChannel{
{
id: 1,
name: "Chan1",
sharedInfo: "batch-shared",
lookupData: "lookup-1",
},
},
expectedPageCalls: 1,
expectedBatchCalls: 1,
},
{
name: "early termination partial page",
maxPageSize: 3,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
},
initialCursor: 0,
earlyTerminationOnPage: 1,
expectedProcessedItems: []processedChannel{
{
id: 1,
name: "Chan1",
sharedInfo: "batch-shared",
lookupData: "lookup-1",
},
{
id: 2,
name: "Chan2",
sharedInfo: "batch-shared",
lookupData: "lookup-2",
},
},
expectedPageCalls: 1,
expectedBatchCalls: 1,
},
{
name: "different initial cursor",
maxPageSize: 2,
allRows: []channelRow{
{
id: 1,
name: "Chan1",
},
{
id: 2,
name: "Chan2",
},
{
id: 3,
name: "Chan3",
},
},
initialCursor: 1,
expectedProcessedItems: []processedChannel{
{
id: 2,
name: "Chan2",
sharedInfo: "batch-shared",
lookupData: "lookup-2",
},
{
id: 3,
name: "Chan3",
sharedInfo: "batch-shared",
lookupData: "lookup-3",
},
},
// [2,3], [].
expectedPageCalls: 2,
// One batch call for the page with [2,3].
expectedBatchCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cfg := &QueryConfig{
MaxPageSize: tt.maxPageSize,
}
var (
processedItems []processedChannel
pageCallCount int
batchCallCount int
)
pageQueryFunc := func(ctx context.Context, cursor int64,
limit int32) ([]channelRow, error) {
pageCallCount++
// Return error on specific call if configured.
//nolint:ll
if tt.pageQueryErrorOnCall > 0 &&
pageCallCount == tt.pageQueryErrorOnCall {
return nil, tt.pageQueryError
}
// Simulate cursor-based pagination.
var items []channelRow
for _, row := range tt.allRows {
if row.id > cursor &&
len(items) < int(limit) {
items = append(items, row)
}
}
// Handle early termination test case.
//nolint:ll
if tt.earlyTerminationOnPage > 0 &&
pageCallCount == tt.earlyTerminationOnPage {
// Return fewer items than maxPageSize
// to trigger termination
if len(items) >= int(tt.maxPageSize) {
items = items[:tt.maxPageSize-1]
}
}
return items, nil
}
extractPageCursor := func(row channelRow) int64 {
return row.id
}
collectFunc := func(row channelRow) (int64, error) {
return row.id, nil
}
batchDataFunc := func(ctx context.Context,
ids []int64) (*channelBatchData, error) {
batchCallCount++
// Return error on specific batch if configured.
//nolint:ll
if tt.batchDataErrorOnBatch > 0 &&
batchCallCount == tt.batchDataErrorOnBatch {
return nil, tt.batchDataError
}
// Create mock batch data.
lookupTable := make(map[int64]string)
for _, id := range ids {
lookupTable[id] =
fmt.Sprintf("lookup-%d", id)
}
return &channelBatchData{
lookupTable: lookupTable,
sharedInfo: "batch-shared",
}, nil
}
processItem := func(ctx context.Context, row channelRow,
batchData *channelBatchData) error {
// Return error on specific item if configured.
if tt.processErrorOnID > 0 &&
row.id == tt.processErrorOnID {
return tt.processError
}
processedChan := processedChannel{
id: row.id,
name: row.name,
sharedInfo: batchData.sharedInfo,
lookupData: batchData.
lookupTable[row.id],
}
processedItems = append(
processedItems, processedChan,
)
return nil
}
err := ExecuteCollectAndBatchWithSharedDataQuery(
ctx, cfg, tt.initialCursor,
pageQueryFunc, extractPageCursor, collectFunc,
batchDataFunc, processItem,
)
// Check error expectations.
if tt.expectedError != "" {
require.ErrorContains(t, err, tt.expectedError)
if tt.pageQueryError != nil {
require.ErrorIs(
t, err, tt.pageQueryError,
)
}
if tt.batchDataError != nil {
require.ErrorIs(
t, err, tt.batchDataError,
)
}
if tt.processError != nil {
require.ErrorIs(
t, err, tt.processError,
)
}
} else {
require.NoError(t, err)
}
// Check processed results.
require.Equal(
t, tt.expectedProcessedItems, processedItems,
)
// Check call counts.
require.Equal(
t, tt.expectedPageCalls, pageCallCount,
)
require.Equal(
t, tt.expectedBatchCalls, batchCallCount,
)
})
}
}