Merge pull request #10100 from Abdulkbk/fix-chanid-flag

commands: fix how we parse chan ids args at CLI level
This commit is contained in:
Oliver Gugger
2025-07-31 06:49:46 -06:00
committed by GitHub
2 changed files with 101 additions and 33 deletions

View File

@@ -183,13 +183,13 @@ func PaymentFlags() []cli.Flag {
cancelableFlag, cancelableFlag,
cltvLimitFlag, cltvLimitFlag,
lastHopFlag, lastHopFlag,
cli.Int64SliceFlag{ cli.StringSliceFlag{
Name: "outgoing_chan_id", Name: "outgoing_chan_id",
Usage: "short channel id of the outgoing channel to " + Usage: "short channel id of the outgoing channel to " +
"use for the first hop of the payment; can " + "use for the first hop of the payment; can " +
"be specified multiple times in the same " + "be specified multiple times in the same " +
"command", "command",
Value: &cli.Int64Slice{}, Value: &cli.StringSlice{},
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "force, f", Name: "force, f",
@@ -521,12 +521,11 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest,
lnClient := lnrpc.NewLightningClient(lnConn) lnClient := lnrpc.NewLightningClient(lnConn)
outChan := ctx.Int64Slice("outgoing_chan_id") var err error
if len(outChan) != 0 { outChan := ctx.StringSlice("outgoing_chan_id")
req.OutgoingChanIds = make([]uint64, len(outChan)) req.OutgoingChanIds, err = parseChanIDs(outChan)
for i, c := range outChan { if err != nil {
req.OutgoingChanIds[i] = uint64(c) return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
}
} }
if ctx.IsSet(lastHopFlag.Name) { if ctx.IsSet(lastHopFlag.Name) {
@@ -1282,17 +1281,9 @@ func queryRoutes(ctx *cli.Context) error {
} }
outgoingChanIds := ctx.StringSlice("outgoing_chan_id") outgoingChanIds := ctx.StringSlice("outgoing_chan_id")
if len(outgoingChanIds) != 0 { req.OutgoingChanIds, err = parseChanIDs(outgoingChanIds)
req.OutgoingChanIds = make([]uint64, len(outgoingChanIds)) if err != nil {
for i, chanID := range outgoingChanIds { return fmt.Errorf("unable to decode outgoing_chan_id: %w", err)
id, err := strconv.ParseUint(chanID, 10, 64)
if err != nil {
return fmt.Errorf("invalid outgoing_chan_id "+
"argument: %w", err)
}
req.OutgoingChanIds[i] = id
}
} }
if ctx.IsSet("route_hints") { if ctx.IsSet("route_hints") {
@@ -1585,13 +1576,13 @@ var forwardingHistoryCommand = cli.Command{
Usage: "skip the peer alias lookup per forwarding " + Usage: "skip the peer alias lookup per forwarding " +
"event in order to improve performance", "event in order to improve performance",
}, },
cli.Int64SliceFlag{ cli.StringSliceFlag{
Name: "incoming_chan_ids", Name: "incoming_chan_ids",
Usage: "the short channel id of the incoming " + Usage: "the short channel id of the incoming " +
"channel to filter events by; can be " + "channel to filter events by; can be " +
"specified multiple times in the same command", "specified multiple times in the same command",
}, },
cli.Int64SliceFlag{ cli.StringSliceFlag{
Name: "outgoing_chan_ids", Name: "outgoing_chan_ids",
Usage: "the short channel id of the outgoing " + Usage: "the short channel id of the outgoing " +
"channel to filter events by; can be " + "channel to filter events by; can be " +
@@ -1677,21 +1668,19 @@ func forwardingHistory(ctx *cli.Context) error {
NumMaxEvents: maxEvents, NumMaxEvents: maxEvents,
PeerAliasLookup: lookupPeerAlias, PeerAliasLookup: lookupPeerAlias,
} }
outgoingChannelIDs := ctx.Int64Slice("outgoing_chan_ids")
if len(outgoingChannelIDs) != 0 { outgoingChannelIDs := ctx.StringSlice("outgoing_chan_ids")
req.OutgoingChanIds = make([]uint64, len(outgoingChannelIDs)) req.OutgoingChanIds, err = parseChanIDs(outgoingChannelIDs)
for i, c := range outgoingChannelIDs { if err != nil {
req.OutgoingChanIds[i] = uint64(c) return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
}
} }
incomingChannelIDs := ctx.Int64Slice("incoming_chan_ids") incomingChannelIDs := ctx.StringSlice("incoming_chan_ids")
if len(incomingChannelIDs) != 0 { req.IncomingChanIds, err = parseChanIDs(incomingChannelIDs)
req.IncomingChanIds = make([]uint64, len(incomingChannelIDs)) if err != nil {
for i, c := range incomingChannelIDs { return fmt.Errorf("unable to decode incoming_chan_ids: %w", err)
req.IncomingChanIds[i] = uint64(c)
}
} }
resp, err := client.ForwardingHistory(ctxc, req) resp, err := client.ForwardingHistory(ctxc, req)
if err != nil { if err != nil {
return err return err
@@ -2060,3 +2049,24 @@ func ordinalNumber(num uint32) string {
return fmt.Sprintf("%dth", num) return fmt.Sprintf("%dth", num)
} }
} }
// parseChanIDs parses a slice of strings containing short channel IDs into a
// slice of uint64 values.
func parseChanIDs(idStrings []string) ([]uint64, error) {
// Return early if no chan IDs are passed.
if len(idStrings) == 0 {
return nil, nil
}
chanIDs := make([]uint64, len(idStrings))
for i, idStr := range idStrings {
scid, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
return nil, err
}
chanIDs[i] = scid
}
return chanIDs, nil
}

View File

@@ -434,3 +434,61 @@ func TestParseBlockHeightInputs(t *testing.T) {
}) })
} }
} }
// TestParseChanIDs tests the parseChanIDs function with various
// valid and invalid input values and verifies the output.
func TestParseChanIDs(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
chanIDs []string
expected []uint64
expectedErr bool
}{
{
name: "valid chan ids",
chanIDs: []string{
"1499733860352000", "17592186044552773633",
},
expected: []uint64{
1499733860352000, 17592186044552773633,
},
expectedErr: false,
},
{
name: "invalid chan id",
chanIDs: []string{
"channel id",
},
expected: []uint64{},
expectedErr: true,
},
{
name: "negative chan id",
chanIDs: []string{
"-10000",
},
expected: []uint64{},
expectedErr: true,
},
{
name: "empty chan ids",
chanIDs: []string{},
expected: nil,
expectedErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
chanIDs, err := parseChanIDs(tc.chanIDs)
if tc.expectedErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.expected, chanIDs)
})
}
}