diff --git a/cmd/commands/cmd_payments.go b/cmd/commands/cmd_payments.go index b746dbbfd..d13b52da2 100644 --- a/cmd/commands/cmd_payments.go +++ b/cmd/commands/cmd_payments.go @@ -183,13 +183,13 @@ func PaymentFlags() []cli.Flag { cancelableFlag, cltvLimitFlag, lastHopFlag, - cli.Int64SliceFlag{ + cli.StringSliceFlag{ Name: "outgoing_chan_id", Usage: "short channel id of the outgoing channel to " + "use for the first hop of the payment; can " + "be specified multiple times in the same " + "command", - Value: &cli.Int64Slice{}, + Value: &cli.StringSlice{}, }, cli.BoolFlag{ Name: "force, f", @@ -521,12 +521,11 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest, lnClient := lnrpc.NewLightningClient(lnConn) - outChan := ctx.Int64Slice("outgoing_chan_id") - if len(outChan) != 0 { - req.OutgoingChanIds = make([]uint64, len(outChan)) - for i, c := range outChan { - req.OutgoingChanIds[i] = uint64(c) - } + var err error + outChan := ctx.StringSlice("outgoing_chan_id") + req.OutgoingChanIds, err = parseChanIDs(outChan) + if err != nil { + return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err) } if ctx.IsSet(lastHopFlag.Name) { @@ -1282,17 +1281,9 @@ func queryRoutes(ctx *cli.Context) error { } outgoingChanIds := ctx.StringSlice("outgoing_chan_id") - if len(outgoingChanIds) != 0 { - req.OutgoingChanIds = make([]uint64, len(outgoingChanIds)) - for i, chanID := range outgoingChanIds { - id, err := strconv.ParseUint(chanID, 10, 64) - if err != nil { - return fmt.Errorf("invalid outgoing_chan_id "+ - "argument: %w", err) - } - - req.OutgoingChanIds[i] = id - } + req.OutgoingChanIds, err = parseChanIDs(outgoingChanIds) + if err != nil { + return fmt.Errorf("unable to decode outgoing_chan_id: %w", err) } if ctx.IsSet("route_hints") { @@ -1585,13 +1576,13 @@ var forwardingHistoryCommand = cli.Command{ Usage: "skip the peer alias lookup per forwarding " + "event in order to improve performance", }, - cli.Int64SliceFlag{ + cli.StringSliceFlag{ Name: "incoming_chan_ids", Usage: "the short channel id of the incoming " + "channel to filter events by; can be " + "specified multiple times in the same command", }, - cli.Int64SliceFlag{ + cli.StringSliceFlag{ Name: "outgoing_chan_ids", Usage: "the short channel id of the outgoing " + "channel to filter events by; can be " + @@ -1677,21 +1668,19 @@ func forwardingHistory(ctx *cli.Context) error { NumMaxEvents: maxEvents, PeerAliasLookup: lookupPeerAlias, } - outgoingChannelIDs := ctx.Int64Slice("outgoing_chan_ids") - if len(outgoingChannelIDs) != 0 { - req.OutgoingChanIds = make([]uint64, len(outgoingChannelIDs)) - for i, c := range outgoingChannelIDs { - req.OutgoingChanIds[i] = uint64(c) - } + + outgoingChannelIDs := ctx.StringSlice("outgoing_chan_ids") + req.OutgoingChanIds, err = parseChanIDs(outgoingChannelIDs) + if err != nil { + return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err) } - incomingChannelIDs := ctx.Int64Slice("incoming_chan_ids") - if len(incomingChannelIDs) != 0 { - req.IncomingChanIds = make([]uint64, len(incomingChannelIDs)) - for i, c := range incomingChannelIDs { - req.IncomingChanIds[i] = uint64(c) - } + incomingChannelIDs := ctx.StringSlice("incoming_chan_ids") + req.IncomingChanIds, err = parseChanIDs(incomingChannelIDs) + if err != nil { + return fmt.Errorf("unable to decode incoming_chan_ids: %w", err) } + resp, err := client.ForwardingHistory(ctxc, req) if err != nil { return err @@ -2060,3 +2049,24 @@ func ordinalNumber(num uint32) string { return fmt.Sprintf("%dth", num) } } + +// parseChanIDs parses a slice of strings containing short channel IDs into a +// slice of uint64 values. +func parseChanIDs(idStrings []string) ([]uint64, error) { + // Return early if no chan IDs are passed. + if len(idStrings) == 0 { + return nil, nil + } + + chanIDs := make([]uint64, len(idStrings)) + for i, idStr := range idStrings { + scid, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return nil, err + } + + chanIDs[i] = scid + } + + return chanIDs, nil +} diff --git a/cmd/commands/commands_test.go b/cmd/commands/commands_test.go index c6baad258..12e861cd5 100644 --- a/cmd/commands/commands_test.go +++ b/cmd/commands/commands_test.go @@ -434,3 +434,61 @@ func TestParseBlockHeightInputs(t *testing.T) { }) } } + +// TestParseChanIDs tests the parseChanIDs function with various +// valid and invalid input values and verifies the output. +func TestParseChanIDs(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chanIDs []string + expected []uint64 + expectedErr bool + }{ + { + name: "valid chan ids", + chanIDs: []string{ + "1499733860352000", "17592186044552773633", + }, + expected: []uint64{ + 1499733860352000, 17592186044552773633, + }, + expectedErr: false, + }, + { + name: "invalid chan id", + chanIDs: []string{ + "channel id", + }, + expected: []uint64{}, + expectedErr: true, + }, + { + name: "negative chan id", + chanIDs: []string{ + "-10000", + }, + expected: []uint64{}, + expectedErr: true, + }, + { + name: "empty chan ids", + chanIDs: []string{}, + expected: nil, + expectedErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + chanIDs, err := parseChanIDs(tc.chanIDs) + if tc.expectedErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.expected, chanIDs) + }) + } +}