Merge pull request #10100 from Abdulkbk/fix-chanid-flag

commands: fix how we parse chan ids args at CLI level
This commit is contained in:
Oliver Gugger
2025-07-31 06:49:46 -06:00
committed by GitHub
2 changed files with 101 additions and 33 deletions

View File

@@ -183,13 +183,13 @@ func PaymentFlags() []cli.Flag {
cancelableFlag,
cltvLimitFlag,
lastHopFlag,
cli.Int64SliceFlag{
cli.StringSliceFlag{
Name: "outgoing_chan_id",
Usage: "short channel id of the outgoing channel to " +
"use for the first hop of the payment; can " +
"be specified multiple times in the same " +
"command",
Value: &cli.Int64Slice{},
Value: &cli.StringSlice{},
},
cli.BoolFlag{
Name: "force, f",
@@ -521,12 +521,11 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest,
lnClient := lnrpc.NewLightningClient(lnConn)
outChan := ctx.Int64Slice("outgoing_chan_id")
if len(outChan) != 0 {
req.OutgoingChanIds = make([]uint64, len(outChan))
for i, c := range outChan {
req.OutgoingChanIds[i] = uint64(c)
}
var err error
outChan := ctx.StringSlice("outgoing_chan_id")
req.OutgoingChanIds, err = parseChanIDs(outChan)
if err != nil {
return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
}
if ctx.IsSet(lastHopFlag.Name) {
@@ -1282,17 +1281,9 @@ func queryRoutes(ctx *cli.Context) error {
}
outgoingChanIds := ctx.StringSlice("outgoing_chan_id")
if len(outgoingChanIds) != 0 {
req.OutgoingChanIds = make([]uint64, len(outgoingChanIds))
for i, chanID := range outgoingChanIds {
id, err := strconv.ParseUint(chanID, 10, 64)
if err != nil {
return fmt.Errorf("invalid outgoing_chan_id "+
"argument: %w", err)
}
req.OutgoingChanIds[i] = id
}
req.OutgoingChanIds, err = parseChanIDs(outgoingChanIds)
if err != nil {
return fmt.Errorf("unable to decode outgoing_chan_id: %w", err)
}
if ctx.IsSet("route_hints") {
@@ -1585,13 +1576,13 @@ var forwardingHistoryCommand = cli.Command{
Usage: "skip the peer alias lookup per forwarding " +
"event in order to improve performance",
},
cli.Int64SliceFlag{
cli.StringSliceFlag{
Name: "incoming_chan_ids",
Usage: "the short channel id of the incoming " +
"channel to filter events by; can be " +
"specified multiple times in the same command",
},
cli.Int64SliceFlag{
cli.StringSliceFlag{
Name: "outgoing_chan_ids",
Usage: "the short channel id of the outgoing " +
"channel to filter events by; can be " +
@@ -1677,21 +1668,19 @@ func forwardingHistory(ctx *cli.Context) error {
NumMaxEvents: maxEvents,
PeerAliasLookup: lookupPeerAlias,
}
outgoingChannelIDs := ctx.Int64Slice("outgoing_chan_ids")
if len(outgoingChannelIDs) != 0 {
req.OutgoingChanIds = make([]uint64, len(outgoingChannelIDs))
for i, c := range outgoingChannelIDs {
req.OutgoingChanIds[i] = uint64(c)
}
outgoingChannelIDs := ctx.StringSlice("outgoing_chan_ids")
req.OutgoingChanIds, err = parseChanIDs(outgoingChannelIDs)
if err != nil {
return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
}
incomingChannelIDs := ctx.Int64Slice("incoming_chan_ids")
if len(incomingChannelIDs) != 0 {
req.IncomingChanIds = make([]uint64, len(incomingChannelIDs))
for i, c := range incomingChannelIDs {
req.IncomingChanIds[i] = uint64(c)
}
incomingChannelIDs := ctx.StringSlice("incoming_chan_ids")
req.IncomingChanIds, err = parseChanIDs(incomingChannelIDs)
if err != nil {
return fmt.Errorf("unable to decode incoming_chan_ids: %w", err)
}
resp, err := client.ForwardingHistory(ctxc, req)
if err != nil {
return err
@@ -2060,3 +2049,24 @@ func ordinalNumber(num uint32) string {
return fmt.Sprintf("%dth", num)
}
}
// parseChanIDs parses a slice of strings containing short channel IDs into a
// slice of uint64 values.
func parseChanIDs(idStrings []string) ([]uint64, error) {
// Return early if no chan IDs are passed.
if len(idStrings) == 0 {
return nil, nil
}
chanIDs := make([]uint64, len(idStrings))
for i, idStr := range idStrings {
scid, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
return nil, err
}
chanIDs[i] = scid
}
return chanIDs, nil
}

View File

@@ -434,3 +434,61 @@ func TestParseBlockHeightInputs(t *testing.T) {
})
}
}
// TestParseChanIDs tests the parseChanIDs function with various
// valid and invalid input values and verifies the output.
func TestParseChanIDs(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
chanIDs []string
expected []uint64
expectedErr bool
}{
{
name: "valid chan ids",
chanIDs: []string{
"1499733860352000", "17592186044552773633",
},
expected: []uint64{
1499733860352000, 17592186044552773633,
},
expectedErr: false,
},
{
name: "invalid chan id",
chanIDs: []string{
"channel id",
},
expected: []uint64{},
expectedErr: true,
},
{
name: "negative chan id",
chanIDs: []string{
"-10000",
},
expected: []uint64{},
expectedErr: true,
},
{
name: "empty chan ids",
chanIDs: []string{},
expected: nil,
expectedErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
chanIDs, err := parseChanIDs(tc.chanIDs)
if tc.expectedErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.expected, chanIDs)
})
}
}