routing+channeldb: make MPPayment into an interface

This commit turns `MPPayment` into an interface inside `routing`. Having
this interface gives us the benefit to write more granular unit tests
inside payment lifecycle. As seen from the modified unit tests, several
hacky ways of testing the `SendPayment` method is now replaced by a mock
over `MPPayment`.
This commit is contained in:
yyforyongyu
2023-02-09 12:51:43 +08:00
committed by Olaoluwa Osuntokun
parent c412ab5ccb
commit 34d0e5d4c5
8 changed files with 257 additions and 167 deletions

View File

@@ -491,7 +491,7 @@ func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash,
}
func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
dbMPPayment, error) {
m.Lock()
defer m.Unlock()
@@ -750,10 +750,8 @@ func (m *mockControlTower) FailPayment(phash lntypes.Hash,
}
func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
dbMPPayment, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash)
// Type assertion on nil will fail, so we check and return here.
@@ -761,15 +759,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
return nil, args.Error(1)
}
// Make a copy of the payment here to avoid data race.
p := args.Get(0).(*channeldb.MPPayment)
payment := &channeldb.MPPayment{
Info: p.Info,
FailureReason: p.FailureReason,
}
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
copy(payment.HTLCs, p.HTLCs)
payment := args.Get(0).(*mockMPPayment)
return payment, args.Error(1)
}
@@ -794,6 +784,54 @@ func (m *mockControlTower) SubscribeAllPayments() (
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}
type mockMPPayment struct {
mock.Mock
}
var _ dbMPPayment = (*mockMPPayment)(nil)
func (m *mockMPPayment) GetState() *channeldb.MPPaymentState {
args := m.Called()
return args.Get(0).(*channeldb.MPPaymentState)
}
func (m *mockMPPayment) GetStatus() channeldb.PaymentStatus {
args := m.Called()
return args.Get(0).(channeldb.PaymentStatus)
}
func (m *mockMPPayment) Terminated() bool {
args := m.Called()
return args.Bool(0)
}
func (m *mockMPPayment) NeedWaitAttempts() (bool, error) {
args := m.Called()
return args.Bool(0), args.Error(1)
}
func (m *mockMPPayment) GetHTLCs() []channeldb.HTLCAttempt {
args := m.Called()
return args.Get(0).([]channeldb.HTLCAttempt)
}
func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt {
args := m.Called()
return args.Get(0).([]channeldb.HTLCAttempt)
}
func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason {
args := m.Called()
reason := args.Get(0)
if reason == nil {
return nil
}
return reason.(*channeldb.FailureReason)
}
type mockLink struct {
htlcswitch.ChannelLink
bandwidth lnwire.MilliSatoshi