diff --git a/sqldb/paginate.go b/sqldb/paginate.go
index 4d8caddb6..2a51218c5 100644
--- a/sqldb/paginate.go
+++ b/sqldb/paginate.go
@@ -157,3 +157,99 @@ func ExecutePaginatedQuery[C any, T any](ctx context.Context, cfg *QueryConfig,
	return nil
}

// CollectAndBatchDataQueryFunc represents a function that batch loads
// additional data for collected identifiers, returning the batch data that
// applies to all items.
type CollectAndBatchDataQueryFunc[ID any, BatchData any] func(context.Context,
	[]ID) (BatchData, error)

// ItemWithBatchDataProcessFunc represents a function that processes individual
// items along with shared batch data.
type ItemWithBatchDataProcessFunc[T any, BatchData any] func(context.Context,
	T, BatchData) error

// CollectFunc represents a function that extracts an identifier from a
// paginated item.
type CollectFunc[T any, ID any] func(T) (ID, error)

// ExecuteCollectAndBatchWithSharedDataQuery implements a page-by-page
// processing pattern where each page is immediately processed with
// batch-loaded data before moving to the next page.
//
// It:
//  1. Fetches a page of items using cursor-based pagination
//  2. Collects identifiers from that page and batch loads shared data
//  3. Processes each item in the page with the shared batch data
//  4. Moves to the next page and repeats
//
// Parameters:
//   - initialCursor: starting cursor for pagination
//   - pageQueryFunc: fetches a page of items
//   - extractPageCursor: extracts cursor from paginated item for next page
//   - collectFunc: extracts identifier from paginated item
//   - batchDataFunc: batch loads shared data from collected IDs for one page
//   - processItem: processes each item with the shared batch data
func ExecuteCollectAndBatchWithSharedDataQuery[C any, T any, I any, D any](
	ctx context.Context, cfg *QueryConfig, initialCursor C,
	pageQueryFunc PagedQueryFunc[C, T],
	extractPageCursor CursorExtractFunc[T, C],
	collectFunc CollectFunc[T, I],
	batchDataFunc CollectAndBatchDataQueryFunc[I, D],
	processItem ItemWithBatchDataProcessFunc[T, D]) error {

	cursor := initialCursor

	for {
		// Step 1: Fetch the next page of items.
		items, err := pageQueryFunc(ctx, cursor, cfg.MaxPageSize)
		if err != nil {
			return fmt.Errorf("failed to fetch page with "+
				"cursor %v: %w", cursor, err)
		}

		// If no items returned, we're done.
		if len(items) == 0 {
			break
		}

		// Step 2: Collect identifiers from this page and batch load
		// data.
		pageIDs := make([]I, len(items))
		for i, item := range items {
			pageIDs[i], err = collectFunc(item)
			if err != nil {
				return fmt.Errorf("failed to collect "+
					"identifier from item: %w", err)
			}
		}

		// Batch load shared data for this page.
		batchData, err := batchDataFunc(ctx, pageIDs)
		if err != nil {
			return fmt.Errorf("failed to load batch data for "+
				"page: %w", err)
		}

		// Step 3: Process each item in this page with the shared batch
		// data.
		for _, item := range items {
			err := processItem(ctx, item, batchData)
			if err != nil {
				return fmt.Errorf("failed to process item "+
					"with batch data: %w", err)
			}
		}

		// Advance the cursor once per page, from the last item of the
		// page, rather than reassigning it on every iteration of the
		// processing loop above. The cursor is only consulted when
		// fetching the next page, so per-item updates were redundant.
		cursor = extractPageCursor(items[len(items)-1])

		// If the number of items is less than the max page size,
		// we assume there are no more items to fetch.
		if len(items) < int(cfg.MaxPageSize) {
			break
		}
	}

	return nil
}
diff --git a/sqldb/paginate_test.go b/sqldb/paginate_test.go
index 1de3a02b8..dff0a16b1 100644
--- a/sqldb/paginate_test.go
+++ b/sqldb/paginate_test.go
@@ -558,3 +558,533 @@ func TestExecutePaginatedQuery(t *testing.T) {
	})
}

// TestExecuteCollectAndBatchWithSharedDataQuery tests the
// ExecuteCollectAndBatchWithSharedDataQuery function which processes items in
// pages, allowing for efficient querying and processing of large datasets with
// shared data across batches.
func TestExecuteCollectAndBatchWithSharedDataQuery(t *testing.T) {
	t.Parallel()

	type channelRow struct {
		id        int64
		name      string
		policyIDs []int64
	}

	type channelBatchData struct {
		lookupTable map[int64]string
		sharedInfo  string
	}

	type processedChannel struct {
		id         int64
		name       string
		sharedInfo string
		lookupData string
	}

	tests := []struct {
		name                   string
		maxPageSize            int32
		allRows                []channelRow
		initialCursor          int64
		pageQueryError         error
		pageQueryErrorOnCall   int
		batchDataError         error
		batchDataErrorOnBatch  int
		processError           error
		processErrorOnID       int64
		earlyTerminationOnPage int
		expectedError          string
		expectedProcessedItems []processedChannel
		expectedPageCalls      int
		expectedBatchCalls     int
	}{
		{
			name:        "multiple pages multiple batches",
			maxPageSize: 2,
			allRows: []channelRow{
				{
					id:        1,
					name:      "Chan1",
					policyIDs: []int64{10, 11},
				},
				{
					id:        2,
					name:      "Chan2",
					policyIDs: []int64{20},
				},
				{
					id:        3,
					name:      "Chan3",
					policyIDs: []int64{30, 31},
				},
				{
					id:        4,
					name:      "Chan4",
					policyIDs: []int64{40},
				},
				{
					id:        5,
					name:      "Chan5",
					policyIDs: []int64{50},
				},
			},
			initialCursor: 0,
			expectedProcessedItems: []processedChannel{
				{
					id:         1,
					name:       "Chan1",
					sharedInfo: "batch-shared",
					lookupData: "lookup-1",
				},
				{
					id:         2,
					name:       "Chan2",
					sharedInfo: "batch-shared",
					lookupData: "lookup-2",
				},
				{
					id:         3,
					name:       "Chan3",
					sharedInfo: "batch-shared",
					lookupData: "lookup-3",
				},
				{
					id:         4,
					name:       "Chan4",
					sharedInfo: "batch-shared",
					lookupData: "lookup-4",
				},
				{
					id:         5,
					name:       "Chan5",
					sharedInfo: "batch-shared",
					lookupData: "lookup-5",
				},
			},
			// Pages: [1,2], [3,4], [5].
			expectedPageCalls: 3,
			// One batch call per page with data: [1,2], [3,4], [5].
			expectedBatchCalls: 3,
		},
		{
			name:          "empty results",
			maxPageSize:   10,
			allRows:       []channelRow{},
			initialCursor: 0,
			// One call that returns empty.
			expectedPageCalls: 1,
			// No batches since no items.
			expectedBatchCalls: 0,
		},
		{
			name:        "single page single batch",
			maxPageSize: 10,
			allRows: []channelRow{
				{
					id:        1,
					name:      "Chan1",
					policyIDs: []int64{10},
				},
				{
					id:        2,
					name:      "Chan2",
					policyIDs: []int64{20},
				},
			},
			initialCursor: 0,
			expectedProcessedItems: []processedChannel{
				{
					id:         1,
					name:       "Chan1",
					sharedInfo: "batch-shared",
					lookupData: "lookup-1",
				},
				{
					id:         2,
					name:       "Chan2",
					sharedInfo: "batch-shared",
					lookupData: "lookup-2",
				},
			},
			// One page with all items.
			expectedPageCalls: 1,
			// One batch call for the single page.
			expectedBatchCalls: 1,
		},
		{
			name:        "page query error first call",
			maxPageSize: 5,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
			},
			initialCursor: 0,
			pageQueryError: errors.New(
				"database connection failed",
			),
			pageQueryErrorOnCall: 1,
			expectedError: "failed to fetch page with " +
				"cursor 0",
			expectedPageCalls:  1,
			expectedBatchCalls: 0,
		},
		{
			name:        "page query error second call",
			maxPageSize: 1,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
			},
			initialCursor: 0,
			pageQueryError: errors.New("database error on " +
				"second page"),
			pageQueryErrorOnCall: 2,
			expectedError: "failed to fetch page with " +
				"cursor 1",
			// The first page (item 1) is processed before the
			// second fetch fails.
			expectedProcessedItems: []processedChannel{
				{
					id:         1,
					name:       "Chan1",
					sharedInfo: "batch-shared",
					lookupData: "lookup-1",
				},
			},
			expectedPageCalls:  2,
			expectedBatchCalls: 1,
		},
		{
			name:        "batch data error first batch",
			maxPageSize: 10,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
			},
			initialCursor: 0,
			batchDataError: errors.New("batch loading " +
				"failed"),
			batchDataErrorOnBatch: 1,
			expectedError: "failed to load batch data " +
				"for page",
			expectedPageCalls:  1,
			expectedBatchCalls: 1,
		},
		{
			name:        "batch data error second page",
			maxPageSize: 1,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
			},
			initialCursor: 0,
			batchDataError: errors.New("batch loading " +
				"failed on second page"),
			batchDataErrorOnBatch: 2,
			expectedError: "failed to load batch data " +
				"for page",
			expectedProcessedItems: []processedChannel{
				{
					id:         1,
					name:       "Chan1",
					sharedInfo: "batch-shared",
					lookupData: "lookup-1",
				},
			},
			expectedPageCalls:  2,
			expectedBatchCalls: 2,
		},
		{
			name:        "process error first item",
			maxPageSize: 10,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
			},
			initialCursor:    0,
			processError:     errors.New("processing failed"),
			processErrorOnID: 1,
			expectedError: "failed to process item with " +
				"batch data",
			expectedPageCalls:  1,
			expectedBatchCalls: 1,
		},
		{
			name:        "process error second item",
			maxPageSize: 10,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
			},
			initialCursor:    0,
			processError:     errors.New("processing failed"),
			processErrorOnID: 2,
			expectedError: "failed to process item with batch " +
				"data",
			// Item 1 succeeds before item 2 fails.
			expectedProcessedItems: []processedChannel{
				{
					id:         1,
					name:       "Chan1",
					sharedInfo: "batch-shared",
					lookupData: "lookup-1",
				},
			},
			expectedPageCalls:  1,
			expectedBatchCalls: 1,
		},
		{
			name:        "early termination partial page",
			maxPageSize: 3,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
			},
			initialCursor:          0,
			earlyTerminationOnPage: 1,
			expectedProcessedItems: []processedChannel{
				{
					id:         1,
					name:       "Chan1",
					sharedInfo: "batch-shared",
					lookupData: "lookup-1",
				},
				{
					id:         2,
					name:       "Chan2",
					sharedInfo: "batch-shared",
					lookupData: "lookup-2",
				},
			},
			expectedPageCalls:  1,
			expectedBatchCalls: 1,
		},
		{
			name:        "different initial cursor",
			maxPageSize: 2,
			allRows: []channelRow{
				{
					id:   1,
					name: "Chan1",
				},
				{
					id:   2,
					name: "Chan2",
				},
				{
					id:   3,
					name: "Chan3",
				},
			},
			initialCursor: 1,
			expectedProcessedItems: []processedChannel{
				{
					id:         2,
					name:       "Chan2",
					sharedInfo: "batch-shared",
					lookupData: "lookup-2",
				},
				{
					id:         3,
					name:       "Chan3",
					sharedInfo: "batch-shared",
					lookupData: "lookup-3",
				},
			},
			// [2,3], [].
			expectedPageCalls: 2,
			// One batch call for the page with [2,3].
			expectedBatchCalls: 1,
		},
	}

	// Run each scenario as a subtest; all mock state (call counters,
	// processed items) is declared fresh per subtest.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()
			cfg := &QueryConfig{
				MaxPageSize: tt.maxPageSize,
			}

			var (
				processedItems []processedChannel
				pageCallCount  int
				batchCallCount int
			)

			// Mock page fetcher: returns rows with id > cursor, up
			// to the limit, simulating keyset pagination.
			pageQueryFunc := func(ctx context.Context, cursor int64,
				limit int32) ([]channelRow, error) {

				pageCallCount++

				// Return error on specific call if configured.
				//nolint:ll
				if tt.pageQueryErrorOnCall > 0 &&
					pageCallCount == tt.pageQueryErrorOnCall {

					return nil, tt.pageQueryError
				}

				// Simulate cursor-based pagination.
				var items []channelRow
				for _, row := range tt.allRows {
					if row.id > cursor &&
						len(items) < int(limit) {

						items = append(items, row)
					}
				}

				// Handle early termination test case.
				//nolint:ll
				if tt.earlyTerminationOnPage > 0 &&
					pageCallCount == tt.earlyTerminationOnPage {

					// Return fewer items than maxPageSize
					// to trigger termination
					if len(items) >= int(tt.maxPageSize) {
						items = items[:tt.maxPageSize-1]
					}
				}

				return items, nil
			}

			extractPageCursor := func(row channelRow) int64 {
				return row.id
			}

			collectFunc := func(row channelRow) (int64, error) {
				return row.id, nil
			}

			// Mock batch loader: builds a lookup entry per id plus
			// a shared marker visible to every item in the page.
			batchDataFunc := func(ctx context.Context,
				ids []int64) (*channelBatchData, error) {

				batchCallCount++

				// Return error on specific batch if configured.
				//nolint:ll
				if tt.batchDataErrorOnBatch > 0 &&
					batchCallCount == tt.batchDataErrorOnBatch {

					return nil, tt.batchDataError
				}

				// Create mock batch data.
				lookupTable := make(map[int64]string)
				for _, id := range ids {
					lookupTable[id] =
						fmt.Sprintf("lookup-%d", id)
				}

				return &channelBatchData{
					lookupTable: lookupTable,
					sharedInfo:  "batch-shared",
				}, nil
			}

			// Mock item processor: records each item combined with
			// the shared batch data for later assertions.
			processItem := func(ctx context.Context, row channelRow,
				batchData *channelBatchData) error {

				// Return error on specific item if configured.
				if tt.processErrorOnID > 0 &&
					row.id == tt.processErrorOnID {

					return tt.processError
				}

				processedChan := processedChannel{
					id:         row.id,
					name:       row.name,
					sharedInfo: batchData.sharedInfo,
					lookupData: batchData.
						lookupTable[row.id],
				}

				processedItems = append(
					processedItems, processedChan,
				)
				return nil
			}

			err := ExecuteCollectAndBatchWithSharedDataQuery(
				ctx, cfg, tt.initialCursor,
				pageQueryFunc, extractPageCursor, collectFunc,
				batchDataFunc, processItem,
			)

			// Check error expectations.
			if tt.expectedError != "" {
				require.ErrorContains(t, err, tt.expectedError)
				if tt.pageQueryError != nil {
					require.ErrorIs(
						t, err, tt.pageQueryError,
					)
				}
				if tt.batchDataError != nil {
					require.ErrorIs(
						t, err, tt.batchDataError,
					)
				}
				if tt.processError != nil {
					require.ErrorIs(
						t, err, tt.processError,
					)
				}
			} else {
				require.NoError(t, err)
			}

			// Check processed results.
			require.Equal(
				t, tt.expectedProcessedItems, processedItems,
			)

			// Check call counts.
			require.Equal(
				t, tt.expectedPageCalls, pageCallCount,
			)
			require.Equal(
				t, tt.expectedBatchCalls, batchCallCount,
			)
		})
	}
}