cnct/test: extend mockWitnessBeacon

This commit is contained in:
Joost Jager
2019-04-16 10:22:04 +02:00
parent 064e8492de
commit 16ff4e3ffa

View File

@@ -29,8 +29,16 @@ func (m *mockSigner) ComputeInputScript(tx *wire.MsgTx,
type mockWitnessBeacon struct { type mockWitnessBeacon struct {
preImageUpdates chan lntypes.Preimage preImageUpdates chan lntypes.Preimage
newPreimages chan []lntypes.Preimage newPreimages chan []lntypes.Preimage
lookupPreimage map[lntypes.Hash]lntypes.Preimage
}
func newMockWitnessBeacon() *mockWitnessBeacon {
return &mockWitnessBeacon{
preImageUpdates: make(chan lntypes.Preimage, 1),
newPreimages: make(chan []lntypes.Preimage),
lookupPreimage: make(map[lntypes.Hash]lntypes.Preimage),
}
} }
func (m *mockWitnessBeacon) SubscribeUpdates() *WitnessSubscription { func (m *mockWitnessBeacon) SubscribeUpdates() *WitnessSubscription {
@@ -41,8 +49,12 @@ func (m *mockWitnessBeacon) SubscribeUpdates() *WitnessSubscription {
} }
func (m *mockWitnessBeacon) LookupPreimage(payhash lntypes.Hash) (lntypes.Preimage, bool) { func (m *mockWitnessBeacon) LookupPreimage(payhash lntypes.Hash) (lntypes.Preimage, bool) {
preimage, ok := m.lookupPreimage[payhash]
if !ok {
return lntypes.Preimage{}, false return lntypes.Preimage{}, false
} }
return preimage, true
}
func (m *mockWitnessBeacon) AddPreimages(preimages ...lntypes.Preimage) error { func (m *mockWitnessBeacon) AddPreimages(preimages ...lntypes.Preimage) error {
m.newPreimages <- preimages m.newPreimages <- preimages
@@ -190,10 +202,7 @@ func TestHtlcTimeoutResolver(t *testing.T) {
spendChan: make(chan *chainntnfs.SpendDetail), spendChan: make(chan *chainntnfs.SpendDetail),
confChan: make(chan *chainntnfs.TxConfirmation), confChan: make(chan *chainntnfs.TxConfirmation),
} }
witnessBeacon := &mockWitnessBeacon{ witnessBeacon := newMockWitnessBeacon()
preImageUpdates: make(chan lntypes.Preimage, 1),
newPreimages: make(chan []lntypes.Preimage),
}
for _, testCase := range testCases { for _, testCase := range testCases {
t.Logf("Running test case: %v", testCase.name) t.Logf("Running test case: %v", testCase.name)