mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-25 14:50:43 +02:00
Merge pull request #10100 from Abdulkbk/fix-chanid-flag
commands: fix how we parse chan ids args at CLI level
This commit is contained in:
@@ -183,13 +183,13 @@ func PaymentFlags() []cli.Flag {
|
||||
cancelableFlag,
|
||||
cltvLimitFlag,
|
||||
lastHopFlag,
|
||||
cli.Int64SliceFlag{
|
||||
cli.StringSliceFlag{
|
||||
Name: "outgoing_chan_id",
|
||||
Usage: "short channel id of the outgoing channel to " +
|
||||
"use for the first hop of the payment; can " +
|
||||
"be specified multiple times in the same " +
|
||||
"command",
|
||||
Value: &cli.Int64Slice{},
|
||||
Value: &cli.StringSlice{},
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "force, f",
|
||||
@@ -521,12 +521,11 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest,
|
||||
|
||||
lnClient := lnrpc.NewLightningClient(lnConn)
|
||||
|
||||
outChan := ctx.Int64Slice("outgoing_chan_id")
|
||||
if len(outChan) != 0 {
|
||||
req.OutgoingChanIds = make([]uint64, len(outChan))
|
||||
for i, c := range outChan {
|
||||
req.OutgoingChanIds[i] = uint64(c)
|
||||
}
|
||||
var err error
|
||||
outChan := ctx.StringSlice("outgoing_chan_id")
|
||||
req.OutgoingChanIds, err = parseChanIDs(outChan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
|
||||
}
|
||||
|
||||
if ctx.IsSet(lastHopFlag.Name) {
|
||||
@@ -1282,17 +1281,9 @@ func queryRoutes(ctx *cli.Context) error {
|
||||
}
|
||||
|
||||
outgoingChanIds := ctx.StringSlice("outgoing_chan_id")
|
||||
if len(outgoingChanIds) != 0 {
|
||||
req.OutgoingChanIds = make([]uint64, len(outgoingChanIds))
|
||||
for i, chanID := range outgoingChanIds {
|
||||
id, err := strconv.ParseUint(chanID, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid outgoing_chan_id "+
|
||||
"argument: %w", err)
|
||||
}
|
||||
|
||||
req.OutgoingChanIds[i] = id
|
||||
}
|
||||
req.OutgoingChanIds, err = parseChanIDs(outgoingChanIds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to decode outgoing_chan_id: %w", err)
|
||||
}
|
||||
|
||||
if ctx.IsSet("route_hints") {
|
||||
@@ -1585,13 +1576,13 @@ var forwardingHistoryCommand = cli.Command{
|
||||
Usage: "skip the peer alias lookup per forwarding " +
|
||||
"event in order to improve performance",
|
||||
},
|
||||
cli.Int64SliceFlag{
|
||||
cli.StringSliceFlag{
|
||||
Name: "incoming_chan_ids",
|
||||
Usage: "the short channel id of the incoming " +
|
||||
"channel to filter events by; can be " +
|
||||
"specified multiple times in the same command",
|
||||
},
|
||||
cli.Int64SliceFlag{
|
||||
cli.StringSliceFlag{
|
||||
Name: "outgoing_chan_ids",
|
||||
Usage: "the short channel id of the outgoing " +
|
||||
"channel to filter events by; can be " +
|
||||
@@ -1677,21 +1668,19 @@ func forwardingHistory(ctx *cli.Context) error {
|
||||
NumMaxEvents: maxEvents,
|
||||
PeerAliasLookup: lookupPeerAlias,
|
||||
}
|
||||
outgoingChannelIDs := ctx.Int64Slice("outgoing_chan_ids")
|
||||
if len(outgoingChannelIDs) != 0 {
|
||||
req.OutgoingChanIds = make([]uint64, len(outgoingChannelIDs))
|
||||
for i, c := range outgoingChannelIDs {
|
||||
req.OutgoingChanIds[i] = uint64(c)
|
||||
}
|
||||
|
||||
outgoingChannelIDs := ctx.StringSlice("outgoing_chan_ids")
|
||||
req.OutgoingChanIds, err = parseChanIDs(outgoingChannelIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
|
||||
}
|
||||
|
||||
incomingChannelIDs := ctx.Int64Slice("incoming_chan_ids")
|
||||
if len(incomingChannelIDs) != 0 {
|
||||
req.IncomingChanIds = make([]uint64, len(incomingChannelIDs))
|
||||
for i, c := range incomingChannelIDs {
|
||||
req.IncomingChanIds[i] = uint64(c)
|
||||
}
|
||||
incomingChannelIDs := ctx.StringSlice("incoming_chan_ids")
|
||||
req.IncomingChanIds, err = parseChanIDs(incomingChannelIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to decode incoming_chan_ids: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.ForwardingHistory(ctxc, req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2060,3 +2049,24 @@ func ordinalNumber(num uint32) string {
|
||||
return fmt.Sprintf("%dth", num)
|
||||
}
|
||||
}
|
||||
|
||||
// parseChanIDs parses a slice of strings containing short channel IDs into a
|
||||
// slice of uint64 values.
|
||||
func parseChanIDs(idStrings []string) ([]uint64, error) {
|
||||
// Return early if no chan IDs are passed.
|
||||
if len(idStrings) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chanIDs := make([]uint64, len(idStrings))
|
||||
for i, idStr := range idStrings {
|
||||
scid, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chanIDs[i] = scid
|
||||
}
|
||||
|
||||
return chanIDs, nil
|
||||
}
|
||||
|
@@ -434,3 +434,61 @@ func TestParseBlockHeightInputs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseChanIDs tests the parseChanIDs function with various
|
||||
// valid and invalid input values and verifies the output.
|
||||
func TestParseChanIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
chanIDs []string
|
||||
expected []uint64
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid chan ids",
|
||||
chanIDs: []string{
|
||||
"1499733860352000", "17592186044552773633",
|
||||
},
|
||||
expected: []uint64{
|
||||
1499733860352000, 17592186044552773633,
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid chan id",
|
||||
chanIDs: []string{
|
||||
"channel id",
|
||||
},
|
||||
expected: []uint64{},
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative chan id",
|
||||
chanIDs: []string{
|
||||
"-10000",
|
||||
},
|
||||
expected: []uint64{},
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty chan ids",
|
||||
chanIDs: []string{},
|
||||
expected: nil,
|
||||
expectedErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
chanIDs, err := parseChanIDs(tc.chanIDs)
|
||||
if tc.expectedErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expected, chanIDs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user