From e9269c2093498b12d2ea387a53b85c47b9ddd34c Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 16 Nov 2022 02:15:24 +0800 Subject: [PATCH] channeldb+lnd: rpc server filters payments by date --- channeldb/payments.go | 32 ++++++++++++++++++++++++++- channeldb/payments_test.go | 45 ++++++++++++++++++++++++++++++++++++++ rpcserver.go | 24 ++++++++++++++++++++ 3 files changed, 100 insertions(+), 1 deletion(-) diff --git a/channeldb/payments.go b/channeldb/payments.go index e1febf70d..aef5ddfa2 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -536,6 +536,14 @@ type PaymentsQuery struct { // CountTotal indicates that all payments currently present in the // payment index (complete and incomplete) should be counted. CountTotal bool + + // CreationDateStart, if set, filters out all payments with a creation + // date greater than or euqal to it. + CreationDateStart time.Time + + // CreationDateEnd, if set, filters out all payments with a creation + // date less than or euqal to it. + CreationDateEnd time.Time } // PaymentsResponse contains the result of a query to the payments database. @@ -570,7 +578,11 @@ type PaymentsResponse struct { // to a subset of payments by the payments query, containing an offset // index and a maximum number of returned payments. func (d *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { - var resp PaymentsResponse + var ( + resp PaymentsResponse + startDateSet = !query.CreationDateStart.IsZero() + endDateSet = !query.CreationDateEnd.IsZero() + ) if err := kvdb.View(d, func(tx kvdb.RTx) error { // Get the root payments bucket. @@ -615,6 +627,24 @@ func (d *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { return false, err } + // Skip any payments that were created before the + // specified time. + if startDateSet && payment.Info.CreationTime.Before( + query.CreationDateStart, + ) { + + return false, nil + } + + // Skip any payments that were created after the + // specified time. + if endDateSet && payment.Info.CreationTime.After( + query.CreationDateEnd, + ) { + + return false, nil + } + // At this point, we've exhausted the offset, so we'll // begin collecting invoices found within the range. resp.Payments = append(resp.Payments, payment) diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index f10d5ef46..a3a4ac9ab 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -391,6 +391,48 @@ func TestQueryPayments(t *testing.T) { lastIndex: 4, expectedSeqNrs: []uint64{3, 4}, }, + { + name: "query in forwards order, with start creation " + + "time", + query: PaymentsQuery{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: time.Unix(0, 5), + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in forwards order, with start creation " + + "time at end, overflow", + query: PaymentsQuery{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: time.Unix(0, 7), + }, + firstIndex: 7, + lastIndex: 7, + expectedSeqNrs: []uint64{7}, + }, + { + name: "query with start and end creation time", + query: PaymentsQuery{ + IndexOffset: 9, + MaxPayments: math.MaxUint64, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: time.Unix(0, 3), + CreationDateEnd: time.Unix(0, 5), + }, + firstIndex: 3, + lastIndex: 5, + expectedSeqNrs: []uint64{3, 4, 5}, + }, } for _, tt := range tests { @@ -426,6 +468,9 @@ func TestQueryPayments(t *testing.T) { t.Fatalf("unable to create test "+ "payment: %v", err) } + // Override creation time to allow for testing + // of CreationDateStart and CreationDateEnd. + info.CreationTime = time.Unix(0, int64(i+1)) // Create a new payment entry in the database. err = pControl.InitPayment(info.PaymentIdentifier, info) diff --git a/rpcserver.go b/rpcserver.go index fb3bcc0c3..941d81773 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6359,6 +6359,16 @@ func (r *rpcServer) ListPayments(ctx context.Context, rpcsLog.Debugf("[ListPayments]") + // If both dates are set, we check that the start date is less than the + // end date, otherwise we'll get an empty result. + if req.CreationDateStart != 0 && req.CreationDateEnd != 0 { + if req.CreationDateStart >= req.CreationDateEnd { + return nil, fmt.Errorf("start date(%v) must be before "+ + "end date(%v)", req.CreationDateStart, + req.CreationDateEnd) + } + } + query := channeldb.PaymentsQuery{ IndexOffset: req.IndexOffset, MaxPayments: req.MaxPayments, @@ -6367,6 +6377,20 @@ func (r *rpcServer) ListPayments(ctx context.Context, CountTotal: req.CountTotalPayments, } + // Attach the start date if set. + if req.CreationDateStart != 0 { + query.CreationDateStart = time.Unix( + int64(req.CreationDateStart), 0, + ) + } + + // Attach the end date if set. + if req.CreationDateEnd != 0 { + query.CreationDateEnd = time.Unix( + int64(req.CreationDateEnd), 0, + ) + } + // If the maximum number of payments wasn't specified, then we'll // default to return the maximal number of payments representable. if req.MaxPayments == 0 {