diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 89932f83a..30ca0d67a 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -462,7 +462,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, hopHintsCfg := newSelectHopHintsCfg(cfg, totalHopHints) hopHints, err := PopulateHopHints( - hopHintsCfg, amtMSat, invoice.RouteHints, + ctx, hopHintsCfg, amtMSat, invoice.RouteHints, ) if err != nil { return nil, nil, fmt.Errorf("unable to populate hop "+ @@ -624,10 +624,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // chanCanBeHopHint returns true if the target channel is eligible to be a hop // hint. -func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( - *models.ChannelEdgePolicy, bool) { - - ctx := context.TODO() +func chanCanBeHopHint(ctx context.Context, channel *HopHintInfo, + cfg *SelectHopHintsCfg) (*models.ChannelEdgePolicy, bool) { // Since we're only interested in our private channels, we'll skip // public ones. @@ -862,7 +860,7 @@ func getPotentialHints(cfg *SelectHopHintsCfg) ([]*channeldb.OpenChannel, // shouldIncludeChannel returns true if the channel passes all the checks to // be a hopHint in a given invoice. -func shouldIncludeChannel(cfg *SelectHopHintsCfg, +func shouldIncludeChannel(ctx context.Context, cfg *SelectHopHintsCfg, channel *channeldb.OpenChannel, alreadyIncluded map[uint64]bool) (zpay32.HopHint, lnwire.MilliSatoshi, bool) { @@ -878,7 +876,7 @@ func shouldIncludeChannel(cfg *SelectHopHintsCfg, hopHintInfo := newHopHintInfo(channel, cfg.IsChannelActive(chanID)) // If this channel can't be a hop hint, then skip it. - edgePolicy, canBeHopHint := chanCanBeHopHint(hopHintInfo, cfg) + edgePolicy, canBeHopHint := chanCanBeHopHint(ctx, hopHintInfo, cfg) if edgePolicy == nil || !canBeHopHint { return zpay32.HopHint{}, 0, false } @@ -907,7 +905,7 @@ func shouldIncludeChannel(cfg *SelectHopHintsCfg, // // NOTE: selectHopHints expects potentialHints to be already sorted in // descending priority. -func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, +func selectHopHints(ctx context.Context, cfg *SelectHopHintsCfg, nHintsLeft int, targetBandwidth lnwire.MilliSatoshi, potentialHints []*channeldb.OpenChannel, alreadyIncluded map[uint64]bool) [][]zpay32.HopHint { @@ -923,7 +921,7 @@ func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, } hopHint, remoteBalance, include := shouldIncludeChannel( - cfg, channel, alreadyIncluded, + ctx, cfg, channel, alreadyIncluded, ) if include { @@ -951,8 +949,9 @@ func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, // options that'll append the route hint to the set of all route hints. // // TODO(roasbeef): do proper sub-set sum max hints usually << numChans. -func PopulateHopHints(cfg *SelectHopHintsCfg, amtMSat lnwire.MilliSatoshi, - forcedHints [][]zpay32.HopHint) ([][]zpay32.HopHint, error) { +func PopulateHopHints(ctx context.Context, cfg *SelectHopHintsCfg, + amtMSat lnwire.MilliSatoshi, forcedHints [][]zpay32.HopHint) ( + [][]zpay32.HopHint, error) { hopHints := forcedHints @@ -974,7 +973,7 @@ func PopulateHopHints(cfg *SelectHopHintsCfg, amtMSat lnwire.MilliSatoshi, targetBandwidth := amtMSat * hopHintFactor selectedHints := selectHopHints( - cfg, nHintsLeft, targetBandwidth, potentialHints, + ctx, cfg, nHintsLeft, targetBandwidth, potentialHints, alreadyIncluded, ) diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index d9f2d6da8..8aadff354 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -435,6 +435,7 @@ var shouldIncludeChannelTestCases = []struct { }} func TestShouldIncludeChannel(t *testing.T) { + ctx := context.Background() for _, tc := range shouldIncludeChannelTestCases { tc := tc @@ -456,7 +457,7 @@ func TestShouldIncludeChannel(t *testing.T) { } hopHint, remoteBalance, include := shouldIncludeChannel( - cfg, tc.channel, tc.alreadyIncluded, + ctx, cfg, tc.channel, tc.alreadyIncluded, ) require.Equal(t, tc.include, include) @@ -868,6 +869,7 @@ func setupMockTwoChannels(h *hopHintsConfigMock) (lnwire.ChannelID, } func TestPopulateHopHints(t *testing.T) { + ctx := context.Background() for _, tc := range populateHopHintsTestCases { tc := tc @@ -890,7 +892,7 @@ func TestPopulateHopHints(t *testing.T) { MaxHopHints: tc.maxHopHints, } hopHints, err := PopulateHopHints( - cfg, tc.amount, tc.forcedHints, + ctx, cfg, tc.amount, tc.forcedHints, ) require.NoError(t, err) // We shuffle the elements in the hop hint list so we