From bb398456b557054735a3654cd8ec16e81b897022 Mon Sep 17 00:00:00 2001 From: Sam Korn Date: Tue, 18 Mar 2025 11:19:48 -0600 Subject: [PATCH] cmd: more input parameters checks for listchaintxns cli command add a parsing block height function and error when block heights would produce invalid results --- cmd/commands/commands.go | 60 ++++++++++++------- cmd/commands/commands_test.go | 106 ++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 20 deletions(-) diff --git a/cmd/commands/commands.go b/cmd/commands/commands.go index 7969cd20a..9cbd0bf7a 100644 --- a/cmd/commands/commands.go +++ b/cmd/commands/commands.go @@ -2228,9 +2228,11 @@ var listChainTxnsCommand = cli.Command{ cli.Int64Flag{ Name: "end_height", Usage: "the block height until which to list " + - "transactions, inclusive, to get " + - "transactions until the chain tip, including " + - "unconfirmed, set this value to -1", + "transactions, inclusive; by default this " + + "will return all transactions up to the " + + "chain tip including unconfirmed " + + "transactions", + Value: -1, }, cli.UintFlag{ Name: "index_offset", @@ -2241,7 +2243,7 @@ var listChainTxnsCommand = cli.Command{ }, cli.IntFlag{ Name: "max_transactions", - Usage: "(optional) the max number of transactions to " + + Usage: "the max number of transactions to " + "return; leave at default of 0 to return " + "all transactions", Value: 0, @@ -2251,33 +2253,51 @@ var listChainTxnsCommand = cli.Command{ List all transactions an address of the wallet was involved in. This call will return a list of wallet related transactions that paid - to an address our wallet controls, or spent utxos that we held. The - start_height and end_height flags can be used to specify an inclusive - block range over which to query for transactions. If the end_height is - less than the start_height, transactions will be queried in reverse. - To get all transactions until the chain tip, including unconfirmed - transactions (identifiable with BlockHeight=0), set end_height to -1. - By default, this call will get all transactions our wallet was involved - in, including unconfirmed transactions. -`, + to an address our wallet controls, or spent utxos that we held. + + By default, this call will get all transactions until the chain tip, + including unconfirmed transactions (end_height=-1).`, Action: actionDecorator(listChainTxns), } +func parseBlockHeightInputs(ctx *cli.Context) (int32, int32, error) { + startHeight := int32(ctx.Int64("start_height")) + endHeight := int32(ctx.Int64("end_height")) + + if ctx.IsSet("start_height") && ctx.IsSet("end_height") { + if endHeight != -1 && startHeight > endHeight { + return startHeight, endHeight, + errors.New("start_height should " + + "be less than end_height if " + + "end_height is not equal to -1") + } + } + + if startHeight < 0 { + return startHeight, endHeight, + errors.New("start_height should " + + "be greater than or " + + "equal to 0") + } + + return startHeight, endHeight, nil +} + func listChainTxns(ctx *cli.Context) error { ctxc := getContext() client, cleanUp := getClient(ctx) defer cleanUp() + startHeight, endHeight, err := parseBlockHeightInputs(ctx) + if err != nil { + return err + } + req := &lnrpc.GetTransactionsRequest{ IndexOffset: uint32(ctx.Uint64("index_offset")), MaxTransactions: uint32(ctx.Uint64("max_transactions")), - } - - if ctx.IsSet("start_height") { - req.StartHeight = int32(ctx.Int64("start_height")) - } - if ctx.IsSet("end_height") { - req.EndHeight = int32(ctx.Int64("end_height")) + StartHeight: startHeight, + EndHeight: endHeight, } resp, err := client.GetTransactions(ctxc, req) diff --git a/cmd/commands/commands_test.go b/cmd/commands/commands_test.go index 61bd12920..c6baad258 100644 --- a/cmd/commands/commands_test.go +++ b/cmd/commands/commands_test.go @@ -2,12 +2,14 @@ package commands import ( "encoding/hex" + "flag" "fmt" "math" "strconv" "testing" "github.com/stretchr/testify/require" + "github.com/urfave/cli" ) // TestParseChanPoint tests parseChanPoint with various @@ -328,3 +330,107 @@ func TestAppendChanID(t *testing.T) { }) } } + +// TestParseBlockHeightInputs tests the input block heights and ensure that the +// proper errors are returned when they could lead in invalid results. +func TestParseBlockHeightInputs(t *testing.T) { + t.Parallel() + + app := cli.NewApp() + + startDefault, endDefault := int64(0), int64(-1) + + testCases := []struct { + name string + expectedStart int32 + expectedEnd int32 + expectedErr string + }{ + { + name: "start less than end", + expectedStart: 100, + expectedEnd: 200, + expectedErr: "", + }, + { + name: "start greater than end", + expectedStart: 200, + expectedEnd: 100, + expectedErr: "start_height should be " + + "less than end_height if end_height is " + + "not equal to -1", + }, + { + name: "only start height set", + expectedStart: 100, + expectedEnd: -1, + expectedErr: "", + }, + { + name: "start set and end height set as -1", + expectedStart: 100, + expectedEnd: -1, + expectedErr: "", + }, + { + name: "neither start nor end heights defined", + expectedStart: 0, + expectedEnd: -1, + expectedErr: "", + }, + { + name: "only end height defined", + expectedStart: 0, + expectedEnd: 100, + expectedErr: "", + }, + { + name: "start height is a negative", + expectedStart: -1, + expectedEnd: 100, + expectedErr: "start_height should be greater " + + "than or equal to 0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + flagSet := flag.NewFlagSet( + "listchaintxns", flag.ContinueOnError, + ) + + var startHeight, endHeight int64 + flagSet.Int64Var( + &startHeight, "start_height", startDefault, "", + ) + flagSet.Int64Var( + &endHeight, "end_height", endDefault, "", + ) + + err := flagSet.Set( + "start_height", + strconv.Itoa(int(tc.expectedStart)), + ) + require.NoError( + t, err, "failed to set start_height flag", + ) + + err = flagSet.Set( + "end_height", strconv.Itoa(int(tc.expectedEnd)), + ) + require.NoError( + t, err, "failed to set end_height flag", + ) + + ctx := cli.NewContext(app, flagSet, nil) + start, end, err := parseBlockHeightInputs(ctx) + if tc.expectedErr != "" { + require.EqualError(t, err, tc.expectedErr) + } else { + require.NoError(t, err) + } + require.Equal(t, tc.expectedStart, start) + require.Equal(t, tc.expectedEnd, end) + }) + } +}