From f0d2d1fd0ac2ad190bbefbff162aca891195c692 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 16 Jul 2025 08:29:40 +0200 Subject: [PATCH] sqldb: demonstrate the use of ExecutePagedQuery Here, a new query (GetChannelsByOutpoints) is added which makes use of the /*SLICE:outpoints*/ directive & added workaround. This is then used in a test to demonstrate how the ExecutePagedQuery helper can be used to wrap a query like this such that calls are done in pages. The query that has been added will also be used by live code paths in an upcoming commit. --- sqldb/paginate_test.go | 79 ++++++++++++++++++++++++++++++++++++ sqldb/postgres_test.go | 8 ++++ sqldb/sqlc/graph.sql.go | 68 +++++++++++++++++++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 11 +++++ sqldb/sqlite_test.go | 8 ++++ 6 files changed, 175 insertions(+) diff --git a/sqldb/paginate_test.go b/sqldb/paginate_test.go index 33b9cd93e..0ebf25371 100644 --- a/sqldb/paginate_test.go +++ b/sqldb/paginate_test.go @@ -1,3 +1,5 @@ +//go:build test_db_postgres || test_db_sqlite + package sqldb import ( @@ -6,6 +8,7 @@ import ( "fmt" "testing" + "github.com/lightningnetwork/lnd/sqldb/sqlc" "github.com/stretchr/testify/require" ) @@ -234,3 +237,79 @@ func TestExecutePagedQuery(t *testing.T) { "starting at 2: second page failed") }) } + +// TestSQLSliceQueries tests ExecutePageQuery helper by first showing that a +// query the /*SLICE:*/ directive has a maximum number of +// parameters it can handle, and then showing that the paginated version which +// uses ExecutePagedQuery instead of a raw query can handle more parameters by +// executing the query in pages. +func TestSQLSliceQueries(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := NewTestDB(t) + + // Increase the number of query strings by an order of magnitude each + // iteration until we hit the limit of the backing DB. + // + // NOTE: from testing, the following limits have been noted: + // - for Postgres, the limit is 65535 parameters. + // - for SQLite, the limit is 32766 parameters. + x := 10 + var queryParams []string + for { + for len(queryParams) < x { + queryParams = append( + queryParams, + fmt.Sprintf("%d", len(queryParams)), + ) + } + + _, err := db.GetChannelsByOutpoints(ctx, queryParams) + if err != nil { + if isSQLite { + require.ErrorContains( + t, err, "SQL logic error: too many "+ + "SQL variables", + ) + } else { + require.ErrorContains( + t, err, "extended protocol limited "+ + "to 65535 parameters", + ) + } + break + } + + x *= 10 + + // Just to make sure that the test doesn't carry on too long, + // we assert that we don't exceed a reasonable limit. + require.LessOrEqual(t, x, 100000) + } + + // Now that we have found the limit that the raw query can handle, we + // switch to the wrapped version which will perform the query in pages + // so that the limit is not hit. We use the same number of query params + // that caused the error above. + queryWrapper := func(ctx context.Context, + pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow, + error) { + + return db.GetChannelsByOutpoints(ctx, pageOutpoints) + } + + err := ExecutePagedQuery( + ctx, + DefaultPagedQueryConfig(), + queryParams, + func(s string) string { + return s + }, + queryWrapper, + func(context.Context, sqlc.GetChannelsByOutpointsRow) error { + return nil + }, + ) + require.NoError(t, err) +} diff --git a/sqldb/postgres_test.go b/sqldb/postgres_test.go index 22a29f885..cbeb8ca68 100644 --- a/sqldb/postgres_test.go +++ b/sqldb/postgres_test.go @@ -7,6 +7,14 @@ import ( "testing" ) +// isSQLite is false if the build tag is set to test_db_postgres. It is used in +// tests that compile for both SQLite and Postgres databases to determine +// which database implementation is being used. +// +// TODO(elle): once we've updated to using sqldbv2, we can remove this since +// then we will have access to the DatabaseType on the BaseDB struct at runtime. +const isSQLite = false + // NewTestDB is a helper function that creates a Postgres database for testing. func NewTestDB(t *testing.T) *PostgresStore { pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 79c393880..4e61aa7eb 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -8,6 +8,7 @@ package sqlc import ( "context" "database/sql" + "strings" ) const addSourceNode = `-- name: AddSourceNode :exec @@ -881,6 +882,73 @@ func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannel return items, nil } +const getChannelsByOutpoints = `-- name: GetChannelsByOutpoints :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id +WHERE c.outpoint IN + (/*SLICE:outpoints*/?) +` + +type GetChannelsByOutpointsRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte +} + +func (q *Queries) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) { + query := getChannelsByOutpoints + var queryParams []interface{} + if len(outpoints) > 0 { + for _, v := range outpoints { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:outpoints*/?", makeQueryParams(len(queryParams), len(outpoints)), 1) + } else { + query = strings.Replace(query, "/*SLICE:outpoints*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsByOutpointsRow + for rows.Next() { + var i GetChannelsByOutpointsRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.Node1Pubkey, + &i.Node2Pubkey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelsByPolicyLastUpdateRange = `-- name: GetChannelsByPolicyLastUpdateRange :many SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index ba84db5ac..8155b3de8 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -42,6 +42,7 @@ type Querier interface { GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (GraphChannelPolicy, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) + GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetDatabaseVersion(ctx context.Context) (int32, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 5088b37ee..120d06de8 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -231,6 +231,17 @@ WHERE scid >= @start_scid SELECT * FROM graph_channels WHERE scid = $1 AND version = $2; +-- name: GetChannelsByOutpoints :many +SELECT + sqlc.embed(c), + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id +WHERE c.outpoint IN + (sqlc.slice('outpoints')/*SLICE:outpoints*/); + -- name: GetChannelByOutpoint :one SELECT sqlc.embed(c), diff --git a/sqldb/sqlite_test.go b/sqldb/sqlite_test.go index 9dfb875ea..87fdaea04 100644 --- a/sqldb/sqlite_test.go +++ b/sqldb/sqlite_test.go @@ -7,6 +7,14 @@ import ( "testing" ) +// isSQLite is true if the build tag is set to test_db_sqlite. It is used in +// tests that compile for both SQLite and Postgres databases to determine +// which database implementation is being used. +// +// TODO(elle): once we've updated to using sqldbv2, we can remove this since +// then we will have access to the DatabaseType on the BaseDB struct at runtime. +const isSQLite = true + // NewTestDB is a helper function that creates an SQLite database for testing. func NewTestDB(t *testing.T) *SqliteStore { return NewTestSqliteDB(t)