From 5afd9a56784b73f72614a27c1285dc5b36ab61b1 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 14 Jul 2025 12:28:09 +0200 Subject: [PATCH 1/9] scripts: add sql slices workaround to sqlc gen script This copies the workaround introduced in the taproot-assets code base and will allow us to use `WHERE x in ` type queries. --- scripts/gen_sqlc_docker.sh | 24 ++++++++++++- sqldb/sqlc/db_custom.go | 39 +++++++++++++++++++++ sqldb/sqlc/db_custom_test.go | 67 ++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 sqldb/sqlc/db_custom.go create mode 100644 sqldb/sqlc/db_custom_test.go diff --git a/scripts/gen_sqlc_docker.sh b/scripts/gen_sqlc_docker.sh index 2520b7118..6d728fc42 100755 --- a/scripts/gen_sqlc_docker.sh +++ b/scripts/gen_sqlc_docker.sh @@ -46,4 +46,26 @@ docker run \ -e UID=$UID \ -v "$DIR/../:/build" \ -w /build \ - "sqlc/sqlc:${SQLC_VERSION}" generate \ No newline at end of file + "sqlc/sqlc:${SQLC_VERSION}" generate + +# Because we're using the Postgres dialect of sqlc, we can't use sqlc.slice() +# normally, because sqlc just thinks it can pass the Golang slice directly to +# the database driver. So it doesn't put the /*SLICE:*/ workaround +# comment into the actual SQL query. But we add the comment ourselves and now +# just need to replace the '$X/*SLICE:*/' placeholders with the +# actual placeholder that's going to be replaced by the sqlc generated code. +echo "Applying sqlc.slice() workaround..." +for file in sqldb/sqlc/*.sql.go; do + echo "Patching $file" + + # First, we replace the `$X/*SLICE:*/` placeholders with + # the actual placeholder that sqlc will use: `/*SLICE:*/?`. + sed -i.bak -E 's/\$([0-9]+)\/\*SLICE:([a-zA-Z_][a-zA-Z0-9_]*)\*\//\/\*SLICE:\2\*\/\?/g' "$file" + + # Then, we replace the `strings.Repeat(",?", len(arg.))[1:]` with + # a function call that generates the correct number of placeholders: + # `makeQueryParams(len(queryParams), len(arg.))`. + sed -i.bak -E 's/strings\.Repeat\(",\?", len\(([^)]+)\)\)\[1:\]/makeQueryParams(len(queryParams), len(\1))/g' "$file" + + rm "$file.bak" +done diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go new file mode 100644 index 000000000..2490e5feb --- /dev/null +++ b/sqldb/sqlc/db_custom.go @@ -0,0 +1,39 @@ +package sqlc + +import ( + "fmt" + "strings" +) + +// makeQueryParams generates a string of query parameters for a SQL query. It is +// meant to replace the `?` placeholders in a SQL query with numbered parameters +// like `$1`, `$2`, etc. This is required for the sqlc /*SLICE:*/ +// workaround. See scripts/gen_sqlc_docker.sh for more details. +func makeQueryParams(numTotalArgs, numListArgs int) string { + if numListArgs == 0 { + return "" + } + + var b strings.Builder + + // Pre-allocate a rough estimation of the buffer size to avoid + // re-allocations. A parameter like $1000, takes 6 bytes. + b.Grow(numListArgs * 6) + + diff := numTotalArgs - numListArgs + for i := 0; i < numListArgs; i++ { + if i > 0 { + // We don't need to check the error here because the + // WriteString method of strings.Builder always returns + // nil. + _, _ = b.WriteString(",") + } + + // We don't need to check the error here because the + // Write method (called by fmt.Fprintf) of strings.Builder + // always returns nil. + _, _ = fmt.Fprintf(&b, "$%d", i+diff+1) + } + + return b.String() +} diff --git a/sqldb/sqlc/db_custom_test.go b/sqldb/sqlc/db_custom_test.go new file mode 100644 index 000000000..7db627c88 --- /dev/null +++ b/sqldb/sqlc/db_custom_test.go @@ -0,0 +1,67 @@ +package sqlc + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// BenchmarkMakeQueryParams benchmarks the makeQueryParams function for +// various argument sizes. This helps to ensure the function performs +// efficiently when generating SQL query parameter strings for different +// input sizes. +func BenchmarkMakeQueryParams(b *testing.B) { + cases := []struct { + totalArgs int + listArgs int + }{ + {totalArgs: 5, listArgs: 2}, + {totalArgs: 10, listArgs: 3}, + {totalArgs: 50, listArgs: 10}, + {totalArgs: 100, listArgs: 20}, + } + + for _, c := range cases { + name := fmt.Sprintf( + "totalArgs=%d/listArgs=%d", c.totalArgs, + c.listArgs, + ) + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = makeQueryParams( + c.totalArgs, c.listArgs, + ) + } + }) + } +} + +// TestMakeQueryParams tests the makeQueryParams function for various +// argument sizes and verifies the output matches the expected SQL +// parameter string. The function is assumed to generate a comma-separated +// list of parameters in the form $N for use in SQL queries. +func TestMakeQueryParams(t *testing.T) { + t.Parallel() + + testCases := []struct { + totalArgs int + listArgs int + expected string + }{ + {totalArgs: 5, listArgs: 2, expected: "$4,$5"}, + {totalArgs: 10, listArgs: 3, expected: "$8,$9,$10"}, + {totalArgs: 1, listArgs: 1, expected: "$1"}, + {totalArgs: 3, listArgs: 0, expected: ""}, + {totalArgs: 4, listArgs: 4, expected: "$1,$2,$3,$4"}, + } + + for _, tc := range testCases { + result := makeQueryParams(tc.totalArgs, tc.listArgs) + require.Equal( + t, tc.expected, result, + "unexpected result for totalArgs=%d, "+ + "listArgs=%d", tc.totalArgs, tc.listArgs, + ) + } +} From 006905d57fb9153349e93515c0800ee00fa590e8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 16 Jul 2025 08:27:44 +0200 Subject: [PATCH 2/9] sqldb: add ExecutePagedQuery helper Along with a test for it. This helper will allow us to easily create a pagination wrapper for queries that will make use of the new /*SLICE:*/ directive. The next commit will add a test showing this. --- sqldb/paginate.go | 76 +++++++++++++ sqldb/paginate_test.go | 236 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+) create mode 100644 sqldb/paginate.go create mode 100644 sqldb/paginate_test.go diff --git a/sqldb/paginate.go b/sqldb/paginate.go new file mode 100644 index 000000000..143b456d5 --- /dev/null +++ b/sqldb/paginate.go @@ -0,0 +1,76 @@ +package sqldb + +import ( + "context" + "fmt" +) + +// PagedQueryFunc represents a function that takes a slice of converted items +// and returns results. +type PagedQueryFunc[T any, R any] func(context.Context, []T) ([]R, error) + +// ItemCallbackFunc represents a function that processes individual results. +type ItemCallbackFunc[R any] func(context.Context, R) error + +// ConvertFunc represents a function that converts from input type to query type +type ConvertFunc[I any, T any] func(I) T + +// PagedQueryConfig holds configuration values for calls to ExecutePagedQuery. +type PagedQueryConfig struct { + PageSize int +} + +// DefaultPagedQueryConfig returns a default configuration +func DefaultPagedQueryConfig() *PagedQueryConfig { + return &PagedQueryConfig{ + PageSize: 1000, + } +} + +// ExecutePagedQuery executes a paginated query over a slice of input items. +// It converts the input items to a query type using the provided convertFunc, +// executes the query using the provided queryFunc, and applies the callback +// to each result. +func ExecutePagedQuery[I any, T any, R any](ctx context.Context, + cfg *PagedQueryConfig, inputItems []I, convertFunc ConvertFunc[I, T], + queryFunc PagedQueryFunc[T, R], callback ItemCallbackFunc[R]) error { + + if len(inputItems) == 0 { + return nil + } + + // Process items in pages. + for i := 0; i < len(inputItems); i += cfg.PageSize { + // Calculate the end index for this page. + end := i + cfg.PageSize + if end > len(inputItems) { + end = len(inputItems) + } + + // Get the page slice of input items. + inputPage := inputItems[i:end] + + // Convert only the items needed for this page. + convertedPage := make([]T, len(inputPage)) + for j, inputItem := range inputPage { + convertedPage[j] = convertFunc(inputItem) + } + + // Execute the query for this page. + results, err := queryFunc(ctx, convertedPage) + if err != nil { + return fmt.Errorf("query failed for page "+ + "starting at %d: %w", i, err) + } + + // Apply the callback to each result. + for _, result := range results { + if err := callback(ctx, result); err != nil { + return fmt.Errorf("callback failed for "+ + "result: %w", err) + } + } + } + + return nil +} diff --git a/sqldb/paginate_test.go b/sqldb/paginate_test.go new file mode 100644 index 000000000..33b9cd93e --- /dev/null +++ b/sqldb/paginate_test.go @@ -0,0 +1,236 @@ +package sqldb + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestExecutePagedQuery tests the ExecutePagedQuery function which processes +// items in pages, allowing for efficient querying and processing of large +// datasets. +func TestExecutePagedQuery(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("empty input returns nil", func(t *testing.T) { + var ( + cfg = DefaultPagedQueryConfig() + inputItems []int + ) + + convertFunc := func(i int) string { + return fmt.Sprintf("%d", i) + } + + queryFunc := func(ctx context.Context, items []string) ( + []string, error) { + + require.Fail(t, "queryFunc should not be called "+ + "with empty input") + return nil, nil + } + callback := func(ctx context.Context, result string) error { + require.Fail(t, "callback should not be called with "+ + "empty input") + + return nil + } + + err := ExecutePagedQuery( + ctx, cfg, inputItems, convertFunc, queryFunc, callback, + ) + require.NoError(t, err) + }) + + t.Run("single page processes all items", func(t *testing.T) { + var ( + convertedItems []string + callbackResults []string + inputItems = []int{1, 2, 3, 4, 5} + cfg = &PagedQueryConfig{ + PageSize: 10, + } + ) + + convertFunc := func(i int) string { + return fmt.Sprintf("converted_%d", i) + } + + queryFunc := func(ctx context.Context, + items []string) ([]string, error) { + + convertedItems = append(convertedItems, items...) + results := make([]string, len(items)) + for i, item := range items { + results[i] = fmt.Sprintf("result_%s", item) + } + + return results, nil + } + + callback := func(ctx context.Context, result string) error { + callbackResults = append(callbackResults, result) + return nil + } + + err := ExecutePagedQuery( + ctx, cfg, inputItems, convertFunc, queryFunc, callback, + ) + require.NoError(t, err) + + require.Equal(t, []string{ + "converted_1", "converted_2", "converted_3", + "converted_4", "converted_5", + }, convertedItems) + + require.Equal(t, []string{ + "result_converted_1", "result_converted_2", + "result_converted_3", "result_converted_4", + "result_converted_5", + }, callbackResults) + }) + + t.Run("multiple pages process correctly", func(t *testing.T) { + var ( + queryCallCount int + pageSizes []int + allResults []string + inputItems = []int{1, 2, 3, 4, 5, 6, 7, 8} + cfg = &PagedQueryConfig{ + PageSize: 3, + } + ) + + convertFunc := func(i int) string { + return fmt.Sprintf("item_%d", i) + } + + queryFunc := func(ctx context.Context, + items []string) ([]string, error) { + + queryCallCount++ + pageSizes = append(pageSizes, len(items)) + results := make([]string, len(items)) + for i, item := range items { + results[i] = fmt.Sprintf("result_%s", item) + } + + return results, nil + } + + callback := func(ctx context.Context, result string) error { + allResults = append(allResults, result) + return nil + } + + err := ExecutePagedQuery( + ctx, cfg, inputItems, convertFunc, queryFunc, callback, + ) + require.NoError(t, err) + + // Should have 3 pages: [1,2,3], [4,5,6], [7,8] + require.Equal(t, 3, queryCallCount) + require.Equal(t, []int{3, 3, 2}, pageSizes) + require.Len(t, allResults, 8) + }) + + t.Run("query function error is propagated", func(t *testing.T) { + var ( + cfg = DefaultPagedQueryConfig() + inputItems = []int{1, 2, 3} + ) + + convertFunc := func(i int) string { + return fmt.Sprintf("%d", i) + } + + queryFunc := func(ctx context.Context, + items []string) ([]string, error) { + + return nil, errors.New("query failed") + } + + callback := func(ctx context.Context, result string) error { + require.Fail(t, "callback should not be called when "+ + "query fails") + + return nil + } + + err := ExecutePagedQuery( + ctx, cfg, inputItems, convertFunc, queryFunc, callback, + ) + require.ErrorContains(t, err, "query failed for page "+ + "starting at 0: query failed") + }) + + t.Run("callback error is propagated", func(t *testing.T) { + var ( + cfg = DefaultPagedQueryConfig() + inputItems = []int{1, 2, 3} + ) + + convertFunc := func(i int) string { + return fmt.Sprintf("%d", i) + } + + queryFunc := func(ctx context.Context, + items []string) ([]string, error) { + + return items, nil + } + + callback := func(ctx context.Context, result string) error { + if result == "2" { + return errors.New("callback failed") + } + return nil + } + + err := ExecutePagedQuery( + ctx, cfg, inputItems, convertFunc, queryFunc, callback, + ) + require.ErrorContains(t, err, "callback failed for result: "+ + "callback failed") + }) + + t.Run("query error in second page is propagated", func(t *testing.T) { + var ( + inputItems = []int{1, 2, 3, 4} + cfg = &PagedQueryConfig{ + PageSize: 2, + } + queryCallCount int + ) + + convertFunc := func(i int) string { + return fmt.Sprintf("%d", i) + } + + queryFunc := func(ctx context.Context, + items []string) ([]string, error) { + + queryCallCount++ + if queryCallCount == 2 { + return nil, fmt.Errorf("second page failed") + } + + return items, nil + } + + callback := func(ctx context.Context, result string) error { + return nil + } + + err := ExecutePagedQuery( + ctx, cfg, inputItems, convertFunc, queryFunc, callback, + ) + require.ErrorContains(t, err, "query failed for page "+ + "starting at 2: second page failed") + }) +} From f0d2d1fd0ac2ad190bbefbff162aca891195c692 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 16 Jul 2025 08:29:40 +0200 Subject: [PATCH 3/9] sqldb: demonstrate the use of ExecutePagedQuery Here, a new query (GetChannelsByOutpoints) is added which makes use of the /*SLICE:outpoints*/ directive & added workaround. This is then used in a test to demonstrate how the ExecutePagedQuery helper can be used to wrap a query like this such that calls are done in pages. The query that has been added will also be used by live code paths in an upcoming commit. --- sqldb/paginate_test.go | 79 ++++++++++++++++++++++++++++++++++++ sqldb/postgres_test.go | 8 ++++ sqldb/sqlc/graph.sql.go | 68 +++++++++++++++++++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 11 +++++ sqldb/sqlite_test.go | 8 ++++ 6 files changed, 175 insertions(+) diff --git a/sqldb/paginate_test.go b/sqldb/paginate_test.go index 33b9cd93e..0ebf25371 100644 --- a/sqldb/paginate_test.go +++ b/sqldb/paginate_test.go @@ -1,3 +1,5 @@ +//go:build test_db_postgres || test_db_sqlite + package sqldb import ( @@ -6,6 +8,7 @@ import ( "fmt" "testing" + "github.com/lightningnetwork/lnd/sqldb/sqlc" "github.com/stretchr/testify/require" ) @@ -234,3 +237,79 @@ func TestExecutePagedQuery(t *testing.T) { "starting at 2: second page failed") }) } + +// TestSQLSliceQueries tests ExecutePageQuery helper by first showing that a +// query the /*SLICE:*/ directive has a maximum number of +// parameters it can handle, and then showing that the paginated version which +// uses ExecutePagedQuery instead of a raw query can handle more parameters by +// executing the query in pages. +func TestSQLSliceQueries(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := NewTestDB(t) + + // Increase the number of query strings by an order of magnitude each + // iteration until we hit the limit of the backing DB. + // + // NOTE: from testing, the following limits have been noted: + // - for Postgres, the limit is 65535 parameters. + // - for SQLite, the limit is 32766 parameters. + x := 10 + var queryParams []string + for { + for len(queryParams) < x { + queryParams = append( + queryParams, + fmt.Sprintf("%d", len(queryParams)), + ) + } + + _, err := db.GetChannelsByOutpoints(ctx, queryParams) + if err != nil { + if isSQLite { + require.ErrorContains( + t, err, "SQL logic error: too many "+ + "SQL variables", + ) + } else { + require.ErrorContains( + t, err, "extended protocol limited "+ + "to 65535 parameters", + ) + } + break + } + + x *= 10 + + // Just to make sure that the test doesn't carry on too long, + // we assert that we don't exceed a reasonable limit. + require.LessOrEqual(t, x, 100000) + } + + // Now that we have found the limit that the raw query can handle, we + // switch to the wrapped version which will perform the query in pages + // so that the limit is not hit. We use the same number of query params + // that caused the error above. + queryWrapper := func(ctx context.Context, + pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow, + error) { + + return db.GetChannelsByOutpoints(ctx, pageOutpoints) + } + + err := ExecutePagedQuery( + ctx, + DefaultPagedQueryConfig(), + queryParams, + func(s string) string { + return s + }, + queryWrapper, + func(context.Context, sqlc.GetChannelsByOutpointsRow) error { + return nil + }, + ) + require.NoError(t, err) +} diff --git a/sqldb/postgres_test.go b/sqldb/postgres_test.go index 22a29f885..cbeb8ca68 100644 --- a/sqldb/postgres_test.go +++ b/sqldb/postgres_test.go @@ -7,6 +7,14 @@ import ( "testing" ) +// isSQLite is false if the build tag is set to test_db_postgres. It is used in +// tests that compile for both SQLite and Postgres databases to determine +// which database implementation is being used. +// +// TODO(elle): once we've updated to using sqldbv2, we can remove this since +// then we will have access to the DatabaseType on the BaseDB struct at runtime. +const isSQLite = false + // NewTestDB is a helper function that creates a Postgres database for testing. func NewTestDB(t *testing.T) *PostgresStore { pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 79c393880..4e61aa7eb 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -8,6 +8,7 @@ package sqlc import ( "context" "database/sql" + "strings" ) const addSourceNode = `-- name: AddSourceNode :exec @@ -881,6 +882,73 @@ func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannel return items, nil } +const getChannelsByOutpoints = `-- name: GetChannelsByOutpoints :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id +WHERE c.outpoint IN + (/*SLICE:outpoints*/?) +` + +type GetChannelsByOutpointsRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte +} + +func (q *Queries) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) { + query := getChannelsByOutpoints + var queryParams []interface{} + if len(outpoints) > 0 { + for _, v := range outpoints { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:outpoints*/?", makeQueryParams(len(queryParams), len(outpoints)), 1) + } else { + query = strings.Replace(query, "/*SLICE:outpoints*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsByOutpointsRow + for rows.Next() { + var i GetChannelsByOutpointsRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.Node1Pubkey, + &i.Node2Pubkey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelsByPolicyLastUpdateRange = `-- name: GetChannelsByPolicyLastUpdateRange :many SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index ba84db5ac..8155b3de8 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -42,6 +42,7 @@ type Querier interface { GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (GraphChannelPolicy, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) + GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetDatabaseVersion(ctx context.Context) (int32, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 5088b37ee..120d06de8 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -231,6 +231,17 @@ WHERE scid >= @start_scid SELECT * FROM graph_channels WHERE scid = $1 AND version = $2; +-- name: GetChannelsByOutpoints :many +SELECT + sqlc.embed(c), + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id +WHERE c.outpoint IN + (sqlc.slice('outpoints')/*SLICE:outpoints*/); + -- name: GetChannelByOutpoint :one SELECT sqlc.embed(c), diff --git a/sqldb/sqlite_test.go b/sqldb/sqlite_test.go index 9dfb875ea..87fdaea04 100644 --- a/sqldb/sqlite_test.go +++ b/sqldb/sqlite_test.go @@ -7,6 +7,14 @@ import ( "testing" ) +// isSQLite is true if the build tag is set to test_db_sqlite. It is used in +// tests that compile for both SQLite and Postgres databases to determine +// which database implementation is being used. +// +// TODO(elle): once we've updated to using sqldbv2, we can remove this since +// then we will have access to the DatabaseType on the BaseDB struct at runtime. +const isSQLite = true + // NewTestDB is a helper function that creates an SQLite database for testing. func NewTestDB(t *testing.T) *SqliteStore { return NewTestSqliteDB(t) From 2fa30e87355acfe0a413a9d3b438c3b396fa591e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 16 Jul 2025 08:33:02 +0200 Subject: [PATCH 4/9] graph+config: add sql pagination config to ChannelGraph --- config_test_native_sql.go | 3 ++- graph/db/sql_store.go | 3 +++ graph/db/test_postgres.go | 11 ++++++++--- graph/db/test_sqlite.go | 3 ++- itest/lnd_graph_migration_test.go | 3 ++- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/config_test_native_sql.go b/config_test_native_sql.go index ff1569aa0..b9b4debd0 100644 --- a/config_test_native_sql.go +++ b/config_test_native_sql.go @@ -32,7 +32,8 @@ func (d *DefaultDatabaseBuilder) getGraphStore(baseDB *sqldb.BaseDB, return graphdb.NewSQLStore( &graphdb.SQLStoreConfig{ - ChainHash: *d.cfg.ActiveNetParams.GenesisHash, + ChainHash: *d.cfg.ActiveNetParams.GenesisHash, + PaginationCfg: sqldb.DefaultPagedQueryConfig(), }, graphExecutor, opts..., ) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index d9e6c650e..e61609733 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -182,6 +182,9 @@ type SQLStoreConfig struct { // ChainHash is the genesis hash for the chain that all the gossip // messages in this store are aimed at. ChainHash chainhash.Hash + + // PaginationCfg is the configuration for paginated queries. + PaginationCfg *sqldb.PagedQueryConfig } // NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries diff --git a/graph/db/test_postgres.go b/graph/db/test_postgres.go index 5e10d94cc..a812f998f 100644 --- a/graph/db/test_postgres.go +++ b/graph/db/test_postgres.go @@ -19,7 +19,9 @@ func NewTestDB(t testing.TB) V1Store { // NewTestDBFixture creates a new sqldb.TestPgFixture for testing purposes. func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture { - pgFixture := sqldb.NewTestPgFixture(t, sqldb.DefaultPostgresFixtureLifetime) + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) t.Cleanup(func() { pgFixture.TearDown(t) }) @@ -28,7 +30,9 @@ func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture { // NewTestDBWithFixture is a helper function that creates a SQLStore backed by a // SQL database for testing. -func NewTestDBWithFixture(t testing.TB, pgFixture *sqldb.TestPgFixture) V1Store { +func NewTestDBWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture) V1Store { + var querier BatchedSQLQueries if pgFixture == nil { querier = newBatchQuerier(t) @@ -38,7 +42,8 @@ func NewTestDBWithFixture(t testing.TB, pgFixture *sqldb.TestPgFixture) V1Store store, err := NewSQLStore( &SQLStoreConfig{ - ChainHash: *chaincfg.MainNetParams.GenesisHash, + ChainHash: *chaincfg.MainNetParams.GenesisHash, + PaginationCfg: sqldb.DefaultPagedQueryConfig(), }, querier, ) require.NoError(t, err) diff --git a/graph/db/test_sqlite.go b/graph/db/test_sqlite.go index 4d52b00ba..41773c66e 100644 --- a/graph/db/test_sqlite.go +++ b/graph/db/test_sqlite.go @@ -27,7 +27,8 @@ func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture { func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture) V1Store { store, err := NewSQLStore( &SQLStoreConfig{ - ChainHash: *chaincfg.MainNetParams.GenesisHash, + ChainHash: *chaincfg.MainNetParams.GenesisHash, + PaginationCfg: sqldb.DefaultPagedQueryConfig(), }, newBatchQuerier(t), ) require.NoError(t, err) diff --git a/itest/lnd_graph_migration_test.go b/itest/lnd_graph_migration_test.go index d7417abe4..04745da49 100644 --- a/itest/lnd_graph_migration_test.go +++ b/itest/lnd_graph_migration_test.go @@ -144,7 +144,8 @@ func openNativeSQLGraphDB(ht *lntest.HarnessTest, store, err := graphdb.NewSQLStore( &graphdb.SQLStoreConfig{ - ChainHash: *ht.Miner().ActiveNet.GenesisHash, + ChainHash: *ht.Miner().ActiveNet.GenesisHash, + PaginationCfg: sqldb.DefaultPagedQueryConfig(), }, executor, ) From f72c48b283d13f97af172d2f24dd45ad3237589b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 16 Jul 2025 08:33:29 +0200 Subject: [PATCH 5/9] graph/db+sqldb: pass set of outpoints to SQL This commit adds a new GetChannelsByOutpoints query which takes a slice of outpoint strings. This lets us then update PruneGraph to use paginated calls to GetChannelsByOutpoints instead of making one DB call per outpoint. --- graph/db/sql_store.go | 62 +++++++++++++++++++++++++----------- sqldb/sqlc/graph.sql.go | 40 ----------------------- sqldb/sqlc/querier.go | 1 - sqldb/sqlc/queries/graph.sql | 10 ------ 4 files changed, 44 insertions(+), 69 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index e61609733..cc8b199bc 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -93,7 +93,7 @@ type SQLQueries interface { CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error) - GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error) + GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error) GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) @@ -2365,22 +2365,9 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, prunedNodes []route.Vertex ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - for _, outpoint := range spentOutputs { - // TODO(elle): potentially optimize this by using - // sqlc.slice() once that works for both SQLite and - // Postgres. - // - // NOTE: this fetches channels for all protocol - // versions. - row, err := db.GetChannelByOutpoint( - ctx, outpoint.String(), - ) - if errors.Is(err, sql.ErrNoRows) { - continue - } else if err != nil { - return fmt.Errorf("unable to fetch channel: %w", - err) - } + // Define the callback function for processing each channel. + channelCallback := func(ctx context.Context, + row sqlc.GetChannelsByOutpointsRow) error { node1, node2, err := buildNodeVertices( row.Node1Pubkey, row.Node2Pubkey, @@ -2404,9 +2391,19 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, } closedChans = append(closedChans, info) + + return nil } - err := db.UpsertPruneLogEntry( + err := s.forEachChanInOutpoints( + ctx, db, spentOutputs, channelCallback, + ) + if err != nil { + return fmt.Errorf("unable to fetch channels by "+ + "outpoints: %w", err) + } + + err = db.UpsertPruneLogEntry( ctx, sqlc.UpsertPruneLogEntryParams{ BlockHash: blockHash[:], BlockHeight: int64(blockHeight), @@ -2442,6 +2439,35 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, return closedChans, prunedNodes, nil } +// forEachChanInOutpoints is a helper function that executes a paginated +// query to fetch channels by their outpoints and applies the given call-back +// to each. +// +// NOTE: this fetches channels for all protocol versions. +func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries, + outpoints []*wire.OutPoint, cb func(ctx context.Context, + row sqlc.GetChannelsByOutpointsRow) error) error { + + // Create a wrapper that uses the transaction's db instance to execute + // the query. + queryWrapper := func(ctx context.Context, + pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow, + error) { + + return db.GetChannelsByOutpoints(ctx, pageOutpoints) + } + + // Define the conversion function from Outpoint to string. + outpointToString := func(outpoint *wire.OutPoint) string { + return outpoint.String() + } + + return sqldb.ExecutePagedQuery( + ctx, s.cfg.PaginationCfg, outpoints, outpointToString, + queryWrapper, cb, + ) +} + // ChannelView returns the verifiable edge information for each active channel // within the known channel graph. The set of UTXOs (along with their scripts) // returned are the ones that need to be watched on chain to detect channel diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 4e61aa7eb..755aea5f3 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -358,46 +358,6 @@ func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAn return i, err } -const getChannelByOutpoint = `-- name: GetChannelByOutpoint :one -SELECT - c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, - n1.pub_key AS node1_pubkey, - n2.pub_key AS node2_pubkey -FROM graph_channels c - JOIN graph_nodes n1 ON c.node_id_1 = n1.id - JOIN graph_nodes n2 ON c.node_id_2 = n2.id -WHERE c.outpoint = $1 -` - -type GetChannelByOutpointRow struct { - GraphChannel GraphChannel - Node1Pubkey []byte - Node2Pubkey []byte -} - -func (q *Queries) GetChannelByOutpoint(ctx context.Context, outpoint string) (GetChannelByOutpointRow, error) { - row := q.db.QueryRowContext(ctx, getChannelByOutpoint, outpoint) - var i GetChannelByOutpointRow - err := row.Scan( - &i.GraphChannel.ID, - &i.GraphChannel.Version, - &i.GraphChannel.Scid, - &i.GraphChannel.NodeID1, - &i.GraphChannel.NodeID2, - &i.GraphChannel.Outpoint, - &i.GraphChannel.Capacity, - &i.GraphChannel.BitcoinKey1, - &i.GraphChannel.BitcoinKey2, - &i.GraphChannel.Node1Signature, - &i.GraphChannel.Node2Signature, - &i.GraphChannel.Bitcoin1Signature, - &i.GraphChannel.Bitcoin2Signature, - &i.Node1Pubkey, - &i.Node2Pubkey, - ) - return i, err -} - const getChannelByOutpointWithPolicies = `-- name: GetChannelByOutpointWithPolicies :one SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 8155b3de8..786168bb5 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -35,7 +35,6 @@ type Querier interface { FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) - GetChannelByOutpoint(ctx context.Context, outpoint string) (GetChannelByOutpointRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (GraphChannel, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 120d06de8..f86197749 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -242,16 +242,6 @@ FROM graph_channels c WHERE c.outpoint IN (sqlc.slice('outpoints')/*SLICE:outpoints*/); --- name: GetChannelByOutpoint :one -SELECT - sqlc.embed(c), - n1.pub_key AS node1_pubkey, - n2.pub_key AS node2_pubkey -FROM graph_channels c - JOIN graph_nodes n1 ON c.node_id_1 = n1.id - JOIN graph_nodes n2 ON c.node_id_2 = n2.id -WHERE c.outpoint = $1; - -- name: GetChannelAndNodesBySCID :one SELECT c.*, From 88e9a21d6308a74bae6f82699b4ab051923ff009 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 15 Jul 2025 15:30:45 +0200 Subject: [PATCH 6/9] sqldb+graph/db: update FilterKnownChanIDs to use pagination Remove a TODO by making use of the new sqldb.ExecutePagedQuery to fetch channels in batches rather than one by one. --- graph/db/sql_store.go | 87 +++++++++++++++++++++++++++++------- sqldb/sqlc/graph.sql.go | 59 ++++++++++++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 5 +++ 4 files changed, 136 insertions(+), 16 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index cc8b199bc..5e9305489 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -93,6 +93,7 @@ type SQLQueries interface { CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error) + GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error) GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) @@ -2259,31 +2260,49 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, ctx = context.TODO() newChanIDs []uint64 knownZombies []ChannelUpdateInfo + infoLookup = make( + map[uint64]ChannelUpdateInfo, len(chansInfo), + ) ) + + // We first build a lookup map of the channel ID's to the + // ChannelUpdateInfo. This allows us to quickly delete channels that we + // already know about. + for _, chanInfo := range chansInfo { + infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo + } + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // The call-back function deletes known channels from + // infoLookup, so that we can later check which channels are + // zombies by only looking at the remaining channels in the set. + cb := func(ctx context.Context, + channel sqlc.GraphChannel) error { + + delete(infoLookup, byteOrder.Uint64(channel.Scid)) + + return nil + } + + err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo) + if err != nil { + return fmt.Errorf("unable to iterate through "+ + "channels: %w", err) + } + + // We want to ensure that we deal with the channels in the + // same order that they were passed in, so we iterate over the + // original chansInfo slice and then check if that channel is + // still in the infoLookup map. for _, chanInfo := range chansInfo { channelID := chanInfo.ShortChannelID.ToUint64() - chanIDB := channelIDToBytes(channelID) - - // TODO(elle): potentially optimize this by using - // sqlc.slice() once that works for both SQLite and - // Postgres. - _, err := db.GetChannelBySCID( - ctx, sqlc.GetChannelBySCIDParams{ - Version: int16(ProtocolV1), - Scid: chanIDB, - }, - ) - if err == nil { + if _, ok := infoLookup[channelID]; !ok { continue - } else if !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("unable to fetch channel: %w", - err) } isZombie, err := db.IsZombieChannel( ctx, sqlc.IsZombieChannelParams{ - Scid: chanIDB, + Scid: channelIDToBytes(channelID), Version: int16(ProtocolV1), }, ) @@ -2305,6 +2324,11 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, }, func() { newChanIDs = nil knownZombies = nil + // Rebuild the infoLookup map in case of a rollback. + for _, chanInfo := range chansInfo { + scid := chanInfo.ShortChannelID.ToUint64() + infoLookup[scid] = chanInfo + } }) if err != nil { return nil, nil, fmt.Errorf("unable to fetch channels: %w", err) @@ -2313,6 +2337,37 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, return newChanIDs, knownZombies, nil } +// forEachChanInSCIDList is a helper method that executes a paged query +// against the database to fetch all channels that match the passed +// ChannelUpdateInfo slice. The callback function is called for each channel +// that is found. +func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, + cb func(ctx context.Context, channel sqlc.GraphChannel) error, + chansInfo []ChannelUpdateInfo) error { + + queryWrapper := func(ctx context.Context, + scids [][]byte) ([]sqlc.GraphChannel, error) { + + return db.GetChannelsBySCIDs( + ctx, sqlc.GetChannelsBySCIDsParams{ + Version: int16(ProtocolV1), + Scids: scids, + }, + ) + } + + chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte { + channelID := chanInfo.ShortChannelID.ToUint64() + + return channelIDToBytes(channelID) + } + + return sqldb.ExecutePagedQuery( + ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter, + queryWrapper, cb, + ) +} + // PruneGraphNodes is a garbage collection method which attempts to prune out // any nodes from the channel graph that are currently unconnected. This ensure // that we only maintain a graph of reachable nodes. In the event that a pruned diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 755aea5f3..09df90264 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1154,6 +1154,65 @@ func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsByS return items, nil } +const getChannelsBySCIDs = `-- name: GetChannelsBySCIDs :many +SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM graph_channels +WHERE version = $1 + AND scid IN (/*SLICE:scids*/?) +` + +type GetChannelsBySCIDsParams struct { + Version int16 + Scids [][]byte +} + +func (q *Queries) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) { + query := getChannelsBySCIDs + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Scids) > 0 { + for _, v := range arg.Scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphChannel + for rows.Next() { + var i GraphChannel + if err := rows.Scan( + &i.ID, + &i.Version, + &i.Scid, + &i.NodeID1, + &i.NodeID2, + &i.Outpoint, + &i.Capacity, + &i.BitcoinKey1, + &i.BitcoinKey2, + &i.Node1Signature, + &i.Node2Signature, + &i.Bitcoin1Signature, + &i.Bitcoin2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM graph_node_extra_types diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 786168bb5..01d45734a 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -44,6 +44,7 @@ type Querier interface { GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) + GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error) // This method may return more than one invoice if filter using multiple fields diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index f86197749..5216ddea5 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -231,6 +231,11 @@ WHERE scid >= @start_scid SELECT * FROM graph_channels WHERE scid = $1 AND version = $2; +-- name: GetChannelsBySCIDs :many +SELECT * FROM graph_channels +WHERE version = @version + AND scid IN (sqlc.slice('scids')/*SLICE:scids*/); + -- name: GetChannelsByOutpoints :many SELECT sqlc.embed(c), From e269d57ffa0f55198ab5772425b1d98a5c41f918 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 15 Jul 2025 16:17:53 +0200 Subject: [PATCH 7/9] sqldb+graph/db: use pagination for FetchChanInfos --- graph/db/sql_store.go | 117 +++++++++++++++++----- sqldb/sqlc/graph.sql.go | 185 +++++++++++++++++++++++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 51 ++++++++++ 4 files changed, 330 insertions(+), 24 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 5e9305489..a782f4b45 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -97,6 +97,7 @@ type SQLQueries interface { GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error) GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) + GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) @@ -2170,27 +2171,11 @@ func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) { func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { var ( ctx = context.TODO() - edges []ChannelEdge + edges = make(map[uint64]ChannelEdge) ) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - for _, chanID := range chanIDs { - chanIDB := channelIDToBytes(chanID) - - // TODO(elle): potentially optimize this by using - // sqlc.slice() once that works for both SQLite and - // Postgres. - row, err := db.GetChannelBySCIDWithPolicies( - ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB, - Version: int16(ProtocolV1), - }, - ) - if errors.Is(err, sql.ErrNoRows) { - continue - } else if err != nil { - return fmt.Errorf("unable to fetch channel: %w", - err) - } + chanCallBack := func(ctx context.Context, + row sqlc.GetChannelsBySCIDWithPoliciesRow) error { node1, node2, err := buildNodes( ctx, db, row.GraphNode, row.GraphNode_2, @@ -2225,24 +2210,64 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { "policies: %w", err) } - edges = append(edges, ChannelEdge{ + edges[edge.ChannelID] = ChannelEdge{ Info: edge, Policy1: p1, Policy2: p2, Node1: node1, Node2: node2, - }) + } + + return nil } - return nil + return s.forEachChanWithPoliciesInSCIDList( + ctx, db, chanCallBack, chanIDs, + ) }, func() { - edges = nil + clear(edges) }) if err != nil { return nil, fmt.Errorf("unable to fetch channels: %w", err) } - return edges, nil + res := make([]ChannelEdge, 0, len(edges)) + for _, chanID := range chanIDs { + edge, ok := edges[chanID] + if !ok { + continue + } + + res = append(res, edge) + } + + return res, nil +} + +// forEachChanWithPoliciesInSCIDList is a wrapper around the +// GetChannelsBySCIDWithPolicies query that allows us to iterate through +// channels in a paginated manner. +func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context, + db SQLQueries, cb func(ctx context.Context, + row sqlc.GetChannelsBySCIDWithPoliciesRow) error, + chanIDs []uint64) error { + + queryWrapper := func(ctx context.Context, + scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, + error) { + + return db.GetChannelsBySCIDWithPolicies( + ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{ + Version: int16(ProtocolV1), + Scids: scids, + }, + ) + } + + return sqldb.ExecutePagedQuery( + ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes, + queryWrapper, cb, + ) } // FilterKnownChanIDs takes a set of channel IDs and return the subset of chan @@ -4300,6 +4325,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, var policy1, policy2 *sqlc.GraphChannelPolicy switch r := row.(type) { + case sqlc.GetChannelsBySCIDWithPoliciesRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.GraphChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + MessageFlags: r.Policy1MessageFlags, + ChannelFlags: r.Policy1ChannelFlags, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.GraphChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + MessageFlags: r.Policy2MessageFlags, + ChannelFlags: r.Policy2ChannelFlags, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + case sqlc.GetChannelByOutpointWithPoliciesRow: if r.Policy1ID.Valid { policy1 = &sqlc.GraphChannelPolicy{ diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 09df90264..2f0ba8d67 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1154,6 +1154,191 @@ func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsByS return items, nil } +const getChannelsBySCIDWithPolicies = `-- name: GetChannelsBySCIDWithPolicies :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, + n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature, + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy2_signature + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE + c.version = $1 + AND c.scid IN (/*SLICE:scids*/?) +` + +type GetChannelsBySCIDWithPoliciesParams struct { + Version int16 + Scids [][]byte +} + +type GetChannelsBySCIDWithPoliciesRow struct { + GraphChannel GraphChannel + GraphNode GraphNode + GraphNode_2 GraphNode + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte +} + +func (q *Queries) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) { + query := getChannelsBySCIDWithPolicies + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Scids) > 0 { + for _, v := range arg.Scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsBySCIDWithPoliciesRow + for rows.Next() { + var i GetChannelsBySCIDWithPoliciesRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphNode.ID, + &i.GraphNode.Version, + &i.GraphNode.PubKey, + &i.GraphNode.Alias, + &i.GraphNode.LastUpdate, + &i.GraphNode.Color, + &i.GraphNode.Signature, + &i.GraphNode_2.ID, + &i.GraphNode_2.Version, + &i.GraphNode_2.PubKey, + &i.GraphNode_2.Alias, + &i.GraphNode_2.LastUpdate, + &i.GraphNode_2.Color, + &i.GraphNode_2.Signature, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelsBySCIDs = `-- name: GetChannelsBySCIDs :many SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM graph_channels WHERE version = $1 diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 01d45734a..cd32dc75b 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -44,6 +44,7 @@ type Querier interface { GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) + GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 5216ddea5..52c09e23d 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -283,6 +283,57 @@ WHERE cet.channel_id = $1; SELECT scid from graph_channels WHERE outpoint = $1 AND version = $2; +-- name: GetChannelsBySCIDWithPolicies :many +SELECT + sqlc.embed(c), + sqlc.embed(n1), + sqlc.embed(n2), + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy2_signature + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE + c.version = @version + AND c.scid IN (sqlc.slice('scids')/*SLICE:scids*/); + -- name: GetChannelsByPolicyLastUpdateRange :many SELECT sqlc.embed(c), From de6c030f29b442f8ea6745c4707bb764b995ed15 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 18 Jul 2025 11:26:22 +0200 Subject: [PATCH 8/9] graph/db: let DeleteChannelEdges use new wrapped SQL call Update it to use the new wrapped version of GetChannelsBySCIDWithPolicies to reduce the number of DB calls. --- graph/db/sql_store.go | 49 +++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index a782f4b45..1c02a7b89 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1713,26 +1713,25 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, s.cacheMu.Lock() defer s.cacheMu.Unlock() + // Keep track of which channels we end up finding so that we can + // correctly return ErrEdgeNotFound if we do not find a channel. + chanLookup := make(map[uint64]struct{}, len(chanIDs)) + for _, chanID := range chanIDs { + chanLookup[chanID] = struct{}{} + } + var ( ctx = context.TODO() deleted []*models.ChannelEdgeInfo ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - for _, chanID := range chanIDs { - chanIDB := channelIDToBytes(chanID) + chanCallBack := func(ctx context.Context, + row sqlc.GetChannelsBySCIDWithPoliciesRow) error { - row, err := db.GetChannelBySCIDWithPolicies( - ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB, - Version: int16(ProtocolV1), - }, - ) - if errors.Is(err, sql.ErrNoRows) { - return ErrEdgeNotFound - } else if err != nil { - return fmt.Errorf("unable to fetch channel: %w", - err) - } + // Deleting the entry from the map indicates that we + // have found the channel. + scid := byteOrder.Uint64(row.GraphChannel.Scid) + delete(chanLookup, scid) node1, node2, err := buildNodeVertices( row.GraphNode.PubKey, row.GraphNode_2.PubKey, @@ -1758,7 +1757,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, deleted = append(deleted, info) if !markZombie { - continue + return nil } nodeKey1, nodeKey2 := info.NodeKey1Bytes, @@ -1786,7 +1785,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, err = db.UpsertZombieChannel( ctx, sqlc.UpsertZombieChannelParams{ Version: int16(ProtocolV1), - Scid: chanIDB, + Scid: channelIDToBytes(scid), NodeKey1: nodeKey1[:], NodeKey2: nodeKey2[:], }, @@ -1795,11 +1794,29 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, return fmt.Errorf("unable to mark channel as "+ "zombie: %w", err) } + + return nil + } + + err := s.forEachChanWithPoliciesInSCIDList( + ctx, db, chanCallBack, chanIDs, + ) + if err != nil { + return err + } + + if len(chanLookup) > 0 { + return ErrEdgeNotFound } return nil }, func() { deleted = nil + + // Re-fill the lookup map. + for _, chanID := range chanIDs { + chanLookup[chanID] = struct{}{} + } }) if err != nil { return nil, fmt.Errorf("unable to delete channel edges: %w", From ddc0e95edac64ecb7167bce0398ac2db7e86b8fc Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 18 Jul 2025 11:47:25 +0200 Subject: [PATCH 9/9] graph/db+sqldb: delete channels in batches Use the new `SLICES` directive to add a DeleteChannels query which takes a set of DB channel IDs. Then replace all our calls to DeleteChannel with a paginated call to DeleteChannels. --- graph/db/sql_store.go | 66 ++++++++++++++++++++++++------------ sqldb/sqlc/graph.sql.go | 29 +++++++++++----- sqldb/sqlc/querier.go | 2 +- sqldb/sqlc/queries/graph.sql | 5 +-- 4 files changed, 69 insertions(+), 33 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 1c02a7b89..081154f0f 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -108,7 +108,7 @@ type SQLQueries interface { GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error) GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error) - DeleteChannel(ctx context.Context, id int64) error + DeleteChannels(ctx context.Context, ids []int64) error CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error @@ -1725,6 +1725,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, deleted []*models.ChannelEdgeInfo ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + chanIDsToDelete := make([]int64, 0, len(chanIDs)) chanCallBack := func(ctx context.Context, row sqlc.GetChannelsBySCIDWithPoliciesRow) error { @@ -1748,13 +1749,10 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, return err } - err = db.DeleteChannel(ctx, row.GraphChannel.ID) - if err != nil { - return fmt.Errorf("unable to delete "+ - "channel: %w", err) - } - deleted = append(deleted, info) + chanIDsToDelete = append( + chanIDsToDelete, row.GraphChannel.ID, + ) if !markZombie { return nil @@ -1809,7 +1807,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, return ErrEdgeNotFound } - return nil + return s.deleteChannels(ctx, db, chanIDsToDelete) }, func() { deleted = nil @@ -2462,6 +2460,8 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, prunedNodes []route.Vertex ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + var chansToDelete []int64 + // Define the callback function for processing each channel. channelCallback := func(ctx context.Context, row sqlc.GetChannelsByOutpointsRow) error { @@ -2481,13 +2481,10 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, return err } - err = db.DeleteChannel(ctx, row.GraphChannel.ID) - if err != nil { - return fmt.Errorf("unable to delete "+ - "channel: %w", err) - } - closedChans = append(closedChans, info) + chansToDelete = append( + chansToDelete, row.GraphChannel.ID, + ) return nil } @@ -2500,6 +2497,11 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, "outpoints: %w", err) } + err = s.deleteChannels(ctx, db, chansToDelete) + if err != nil { + return fmt.Errorf("unable to delete channels: %w", err) + } + err = db.UpsertPruneLogEntry( ctx, sqlc.UpsertPruneLogEntryParams{ BlockHash: blockHash[:], @@ -2565,6 +2567,27 @@ func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries, ) } +func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries, + dbIDs []int64) error { + + // Create a wrapper that uses the transaction's db instance to execute + // the query. + queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) { + return nil, db.DeleteChannels(ctx, ids) + } + + idConverter := func(id int64) int64 { + return id + } + + return sqldb.ExecutePagedQuery( + ctx, s.cfg.PaginationCfg, dbIDs, idConverter, + queryWrapper, func(ctx context.Context, _ any) error { + return nil + }, + ) +} + // ChannelView returns the verifiable edge information for each active channel // within the known channel graph. The set of UTXOs (along with their scripts) // returned are the ones that need to be watched on chain to detect channel @@ -2740,7 +2763,8 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) ( return fmt.Errorf("unable to fetch channels: %w", err) } - for _, row := range rows { + chanIDsToDelete := make([]int64, len(rows)) + for i, row := range rows { node1, node2, err := buildNodeVertices( row.Node1PubKey, row.Node2PubKey, ) @@ -2756,15 +2780,15 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) ( return err } - err = db.DeleteChannel(ctx, row.GraphChannel.ID) - if err != nil { - return fmt.Errorf("unable to delete "+ - "channel: %w", err) - } - + chanIDsToDelete[i] = row.GraphChannel.ID removedChans = append(removedChans, channel) } + err = s.deleteChannels(ctx, db, chanIDsToDelete) + if err != nil { + return fmt.Errorf("unable to delete channels: %w", err) + } + return db.DeletePruneLogEntriesInRange( ctx, sqlc.DeletePruneLogEntriesInRangeParams{ StartHeight: int64(height), diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 2f0ba8d67..89a92927f 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -143,15 +143,6 @@ func (q *Queries) CreateChannelExtraType(ctx context.Context, arg CreateChannelE return err } -const deleteChannel = `-- name: DeleteChannel :exec -DELETE FROM graph_channels WHERE id = $1 -` - -func (q *Queries) DeleteChannel(ctx context.Context, id int64) error { - _, err := q.db.ExecContext(ctx, deleteChannel, id) - return err -} - const deleteChannelPolicyExtraTypes = `-- name: DeleteChannelPolicyExtraTypes :exec DELETE FROM graph_channel_policy_extra_types WHERE channel_policy_id = $1 @@ -162,6 +153,26 @@ func (q *Queries) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPoli return err } +const deleteChannels = `-- name: DeleteChannels :exec +DELETE FROM graph_channels +WHERE id IN (/*SLICE:ids*/?) +` + +func (q *Queries) DeleteChannels(ctx context.Context, ids []int64) error { + query := deleteChannels + var queryParams []interface{} + if len(ids) > 0 { + for _, v := range ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) + } + _, err := q.db.ExecContext(ctx, query, queryParams...) + return err +} + const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec DELETE FROM graph_node_extra_types WHERE node_id = $1 diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index cd32dc75b..9b9d010b4 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -18,8 +18,8 @@ type Querier interface { CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) CreateChannelExtraType(ctx context.Context, arg CreateChannelExtraTypeParams) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) - DeleteChannel(ctx context.Context, id int64) error DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error + DeleteChannels(ctx context.Context, ids []int64) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) DeleteNode(ctx context.Context, id int64) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 52c09e23d..8551a7706 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -569,8 +569,9 @@ WHERE c.version = $1 AND c.id > $2 ORDER BY c.id LIMIT $3; --- name: DeleteChannel :exec -DELETE FROM graph_channels WHERE id = $1; +-- name: DeleteChannels :exec +DELETE FROM graph_channels +WHERE id IN (sqlc.slice('ids')/*SLICE:ids*/); /* ───────────────────────────────────────────── graph_channel_features table queries