diff --git a/scripts/gen_sqlc_docker.sh b/scripts/gen_sqlc_docker.sh index 2520b7118..6d728fc42 100755 --- a/scripts/gen_sqlc_docker.sh +++ b/scripts/gen_sqlc_docker.sh @@ -46,4 +46,26 @@ docker run \ -e UID=$UID \ -v "$DIR/../:/build" \ -w /build \ - "sqlc/sqlc:${SQLC_VERSION}" generate \ No newline at end of file + "sqlc/sqlc:${SQLC_VERSION}" generate + +# Because we're using the Postgres dialect of sqlc, we can't use sqlc.slice() +# normally, because sqlc just thinks it can pass the Golang slice directly to +# the database driver. So it doesn't put the /*SLICE:*/ workaround +# comment into the actual SQL query. But we add the comment ourselves and now +# just need to replace the '$X/*SLICE:*/' placeholders with the +# actual placeholder that's going to be replaced by the sqlc generated code. +echo "Applying sqlc.slice() workaround..." +for file in sqldb/sqlc/*.sql.go; do + echo "Patching $file" + + # First, we replace the `$X/*SLICE:*/` placeholders with + # the actual placeholder that sqlc will use: `/*SLICE:*/?`. + sed -i.bak -E 's/\$([0-9]+)\/\*SLICE:([a-zA-Z_][a-zA-Z0-9_]*)\*\//\/\*SLICE:\2\*\/\?/g' "$file" + + # Then, we replace the `strings.Repeat(",?", len(arg.))[1:]` with + # a function call that generates the correct number of placeholders: + # `makeQueryParams(len(queryParams), len(arg.))`. + sed -i.bak -E 's/strings\.Repeat\(",\?", len\(([^)]+)\)\)\[1:\]/makeQueryParams(len(queryParams), len(\1))/g' "$file" + + rm "$file.bak" +done diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go new file mode 100644 index 000000000..2490e5feb --- /dev/null +++ b/sqldb/sqlc/db_custom.go @@ -0,0 +1,39 @@ +package sqlc + +import ( + "fmt" + "strings" +) + +// makeQueryParams generates a string of query parameters for a SQL query. It is +// meant to replace the `?` placeholders in a SQL query with numbered parameters +// like `$1`, `$2`, etc. This is required for the sqlc /*SLICE:*/ +// workaround. See scripts/gen_sqlc_docker.sh for more details. +func makeQueryParams(numTotalArgs, numListArgs int) string { + if numListArgs == 0 { + return "" + } + + var b strings.Builder + + // Pre-allocate a rough estimation of the buffer size to avoid + // re-allocations. A parameter like $1000, takes 6 bytes. + b.Grow(numListArgs * 6) + + diff := numTotalArgs - numListArgs + for i := 0; i < numListArgs; i++ { + if i > 0 { + // We don't need to check the error here because the + // WriteString method of strings.Builder always returns + // nil. + _, _ = b.WriteString(",") + } + + // We don't need to check the error here because the + // Write method (called by fmt.Fprintf) of strings.Builder + // always returns nil. + _, _ = fmt.Fprintf(&b, "$%d", i+diff+1) + } + + return b.String() +} diff --git a/sqldb/sqlc/db_custom_test.go b/sqldb/sqlc/db_custom_test.go new file mode 100644 index 000000000..7db627c88 --- /dev/null +++ b/sqldb/sqlc/db_custom_test.go @@ -0,0 +1,67 @@ +package sqlc + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// BenchmarkMakeQueryParams benchmarks the makeQueryParams function for +// various argument sizes. This helps to ensure the function performs +// efficiently when generating SQL query parameter strings for different +// input sizes. +func BenchmarkMakeQueryParams(b *testing.B) { + cases := []struct { + totalArgs int + listArgs int + }{ + {totalArgs: 5, listArgs: 2}, + {totalArgs: 10, listArgs: 3}, + {totalArgs: 50, listArgs: 10}, + {totalArgs: 100, listArgs: 20}, + } + + for _, c := range cases { + name := fmt.Sprintf( + "totalArgs=%d/listArgs=%d", c.totalArgs, + c.listArgs, + ) + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = makeQueryParams( + c.totalArgs, c.listArgs, + ) + } + }) + } +} + +// TestMakeQueryParams tests the makeQueryParams function for various +// argument sizes and verifies the output matches the expected SQL +// parameter string. The function is assumed to generate a comma-separated +// list of parameters in the form $N for use in SQL queries. +func TestMakeQueryParams(t *testing.T) { + t.Parallel() + + testCases := []struct { + totalArgs int + listArgs int + expected string + }{ + {totalArgs: 5, listArgs: 2, expected: "$4,$5"}, + {totalArgs: 10, listArgs: 3, expected: "$8,$9,$10"}, + {totalArgs: 1, listArgs: 1, expected: "$1"}, + {totalArgs: 3, listArgs: 0, expected: ""}, + {totalArgs: 4, listArgs: 4, expected: "$1,$2,$3,$4"}, + } + + for _, tc := range testCases { + result := makeQueryParams(tc.totalArgs, tc.listArgs) + require.Equal( + t, tc.expected, result, + "unexpected result for totalArgs=%d, "+ + "listArgs=%d", tc.totalArgs, tc.listArgs, + ) + } +}