mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-07 03:06:01 +02:00
scripts: add sql slices workaround to sqlc gen script
This copies the workaround introduced in the taproot-assets code base and will allow us to use `WHERE x in <list>` type queries.
This commit is contained in:
@@ -46,4 +46,26 @@ docker run \
|
||||
-e UID=$UID \
|
||||
-v "$DIR/../:/build" \
|
||||
-w /build \
|
||||
"sqlc/sqlc:${SQLC_VERSION}" generate
|
||||
"sqlc/sqlc:${SQLC_VERSION}" generate
|
||||
|
||||
# Because we're using the Postgres dialect of sqlc, we can't use sqlc.slice()
|
||||
# normally, because sqlc just thinks it can pass the Golang slice directly to
|
||||
# the database driver. So it doesn't put the /*SLICE:<field_name>*/ workaround
|
||||
# comment into the actual SQL query. But we add the comment ourselves and now
|
||||
# just need to replace the '$X/*SLICE:<field_name>*/' placeholders with the
|
||||
# actual placeholder that's going to be replaced by the sqlc generated code.
|
||||
echo "Applying sqlc.slice() workaround..."
|
||||
for file in sqldb/sqlc/*.sql.go; do
|
||||
echo "Patching $file"
|
||||
|
||||
# First, we replace the `$X/*SLICE:<field_name>*/` placeholders with
|
||||
# the actual placeholder that sqlc will use: `/*SLICE:<field_name>*/?`.
|
||||
sed -i.bak -E 's/\$([0-9]+)\/\*SLICE:([a-zA-Z_][a-zA-Z0-9_]*)\*\//\/\*SLICE:\2\*\/\?/g' "$file"
|
||||
|
||||
# Then, we replace the `strings.Repeat(",?", len(arg.<golang_name>))[1:]` with
|
||||
# a function call that generates the correct number of placeholders:
|
||||
# `makeQueryParams(len(queryParams), len(arg.<golang_name>))`.
|
||||
sed -i.bak -E 's/strings\.Repeat\(",\?", len\(([^)]+)\)\)\[1:\]/makeQueryParams(len(queryParams), len(\1))/g' "$file"
|
||||
|
||||
rm "$file.bak"
|
||||
done
|
||||
|
39
sqldb/sqlc/db_custom.go
Normal file
39
sqldb/sqlc/db_custom.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// makeQueryParams generates a string of query parameters for a SQL query. It is
|
||||
// meant to replace the `?` placeholders in a SQL query with numbered parameters
|
||||
// like `$1`, `$2`, etc. This is required for the sqlc /*SLICE:<field_name>*/
|
||||
// workaround. See scripts/gen_sqlc_docker.sh for more details.
|
||||
func makeQueryParams(numTotalArgs, numListArgs int) string {
|
||||
if numListArgs == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Pre-allocate a rough estimation of the buffer size to avoid
|
||||
// re-allocations. A parameter like $1000, takes 6 bytes.
|
||||
b.Grow(numListArgs * 6)
|
||||
|
||||
diff := numTotalArgs - numListArgs
|
||||
for i := 0; i < numListArgs; i++ {
|
||||
if i > 0 {
|
||||
// We don't need to check the error here because the
|
||||
// WriteString method of strings.Builder always returns
|
||||
// nil.
|
||||
_, _ = b.WriteString(",")
|
||||
}
|
||||
|
||||
// We don't need to check the error here because the
|
||||
// Write method (called by fmt.Fprintf) of strings.Builder
|
||||
// always returns nil.
|
||||
_, _ = fmt.Fprintf(&b, "$%d", i+diff+1)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
67
sqldb/sqlc/db_custom_test.go
Normal file
67
sqldb/sqlc/db_custom_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// BenchmarkMakeQueryParams benchmarks the makeQueryParams function for
|
||||
// various argument sizes. This helps to ensure the function performs
|
||||
// efficiently when generating SQL query parameter strings for different
|
||||
// input sizes.
|
||||
func BenchmarkMakeQueryParams(b *testing.B) {
|
||||
cases := []struct {
|
||||
totalArgs int
|
||||
listArgs int
|
||||
}{
|
||||
{totalArgs: 5, listArgs: 2},
|
||||
{totalArgs: 10, listArgs: 3},
|
||||
{totalArgs: 50, listArgs: 10},
|
||||
{totalArgs: 100, listArgs: 20},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
name := fmt.Sprintf(
|
||||
"totalArgs=%d/listArgs=%d", c.totalArgs,
|
||||
c.listArgs,
|
||||
)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = makeQueryParams(
|
||||
c.totalArgs, c.listArgs,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeQueryParams tests the makeQueryParams function for various
|
||||
// argument sizes and verifies the output matches the expected SQL
|
||||
// parameter string. The function is assumed to generate a comma-separated
|
||||
// list of parameters in the form $N for use in SQL queries.
|
||||
func TestMakeQueryParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
totalArgs int
|
||||
listArgs int
|
||||
expected string
|
||||
}{
|
||||
{totalArgs: 5, listArgs: 2, expected: "$4,$5"},
|
||||
{totalArgs: 10, listArgs: 3, expected: "$8,$9,$10"},
|
||||
{totalArgs: 1, listArgs: 1, expected: "$1"},
|
||||
{totalArgs: 3, listArgs: 0, expected: ""},
|
||||
{totalArgs: 4, listArgs: 4, expected: "$1,$2,$3,$4"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := makeQueryParams(tc.totalArgs, tc.listArgs)
|
||||
require.Equal(
|
||||
t, tc.expected, result,
|
||||
"unexpected result for totalArgs=%d, "+
|
||||
"listArgs=%d", tc.totalArgs, tc.listArgs,
|
||||
)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user