diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 43549a261..789558a67 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -76,36 +76,6 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { default: } - // We'll check the current height, if the HTLC has already expired, - // then we'll morph immediately into a resolver that can sweep the - // HTLC. - // - // TODO(roasbeef): use grace period instead? - _, currentHeight, err := h.ChainIO.GetBestBlock() - if err != nil { - return nil, err - } - - // If the current height is >= expiry-1, then a spend will be valid to - // be included in the next block, and we can immediately return the - // resolver. - // - // TODO(joostjager): Statement above may not be valid. For CLTV locks, - // the expiry value is the last _invalid_ block. The likely reason that - // this does not create a problem, is that utxonursery is checking the - // expiry again (in the proper way). Same holds for minus one operation - // below. - // - // Source: - // https://github.com/btcsuite/btcd/blob/991d32e72fe84d5fbf9c47cd604d793a0cd3a072/blockchain/validate.go#L154 - if uint32(currentHeight) >= h.htlcResolution.Expiry-1 { - log.Infof("%T(%v): HTLC has expired (height=%v, expiry=%v), "+ - "transforming into timeout resolver", h, - h.htlcResolution.ClaimOutpoint, currentHeight, - h.htlcResolution.Expiry) - return &h.htlcTimeoutResolver, nil - } - // If we reach this point, then we can't fully act yet, so we'll await // either of our signals triggering: the HTLC expires, or we learn of // the preimage. @@ -125,9 +95,18 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { return nil, errResolverShuttingDown } - // If this new height expires the HTLC, then we can - // exit early and create a resolver that's capable of - // handling the time locked output. + // If the current height is >= expiry-1, then a timeout + // path spend will be valid to be included in the next + // block, and we can immediately return the resolver. + // + // TODO(joostjager): Statement above may not be valid. + // For CLTV locks, the expiry value is the last + // _invalid_ block. The likely reason that this does not + // create a problem, is that utxonursery is checking the + // expiry again (in the proper way). + // + // Source: + // https://github.com/btcsuite/btcd/blob/991d32e72fe84d5fbf9c47cd604d793a0cd3a072/blockchain/validate.go#L154 newHeight := uint32(newBlock.Height) if newHeight >= h.htlcResolution.Expiry-1 { log.Infof("%T(%v): HTLC has expired "+ diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go new file mode 100644 index 000000000..244b438d3 --- /dev/null +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -0,0 +1,187 @@ +package contractcourt + +import ( + "fmt" + "testing" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet" +) + +const ( + outgoingContestHtlcExpiry = 110 +) + +// TestHtlcOutgoingResolverTimeout tests resolution of an offered htlc that +// timed out. +func TestHtlcOutgoingResolverTimeout(t *testing.T) { + t.Parallel() + defer timeout(t)() + + // Setup the resolver with our test resolution. + ctx := newOutgoingResolverTestContext(t) + + // Start the resolution process in a goroutine. + ctx.resolve() + + // Notify arrival of the block after which the timeout path of the htlc + // unlocks. + ctx.notifyEpoch(outgoingContestHtlcExpiry - 1) + + // Assert that the resolver finishes without error and transforms in a + // timeout resolver. + ctx.waitForResult(true) +} + +// TestHtlcOutgoingResolverRemoteClaim tests resolution of an offered htlc that +// is claimed by the remote party. +func TestHtlcOutgoingResolverRemoteClaim(t *testing.T) { + t.Parallel() + defer timeout(t)() + + // Setup the resolver with our test resolution and start the resolution + // process. + ctx := newOutgoingResolverTestContext(t) + ctx.resolve() + + // The remote party sweeps the htlc. Notify our resolver of this event. + preimage := lntypes.Preimage{} + ctx.notifier.spendChan <- &chainntnfs.SpendDetail{ + SpendingTx: &wire.MsgTx{ + TxIn: []*wire.TxIn{ + { + Witness: [][]byte{ + {0}, {1}, {2}, preimage[:], + }, + }, + }, + }, + } + + // We expect the extracted preimage to be added to the witness beacon. + <-ctx.preimageDB.newPreimages + + // We also expect a resolution message to the incoming side of the + // circuit. + <-ctx.resolutionChan + + // Assert that the resolver finishes without error. + ctx.waitForResult(false) +} + +type resolveResult struct { + err error + nextResolver ContractResolver +} + +type outgoingResolverTestContext struct { + resolver *htlcOutgoingContestResolver + notifier *mockNotifier + preimageDB *mockWitnessBeacon + resolverResultChan chan resolveResult + resolutionChan chan ResolutionMsg + t *testing.T +} + +func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { + notifier := &mockNotifier{ + epochChan: make(chan *chainntnfs.BlockEpoch), + spendChan: make(chan *chainntnfs.SpendDetail), + confChan: make(chan *chainntnfs.TxConfirmation), + } + + checkPointChan := make(chan struct{}, 1) + resolutionChan := make(chan ResolutionMsg, 1) + + preimageDB := newMockWitnessBeacon() + + chainCfg := ChannelArbitratorConfig{ + ChainArbitratorConfig: ChainArbitratorConfig{ + Notifier: notifier, + PreimageDB: preimageDB, + DeliverResolutionMsg: func(msgs ...ResolutionMsg) error { + if len(msgs) != 1 { + return fmt.Errorf("expected 1 "+ + "resolution msg, instead got %v", + len(msgs)) + } + + resolutionChan <- msgs[0] + return nil + }, + }, + } + + outgoingRes := lnwallet.OutgoingHtlcResolution{ + Expiry: outgoingContestHtlcExpiry, + SweepSignDesc: input.SignDescriptor{ + Output: &wire.TxOut{}, + }, + } + + resolver := &htlcOutgoingContestResolver{ + htlcTimeoutResolver: htlcTimeoutResolver{ + ResolverKit: ResolverKit{ + ChannelArbitratorConfig: chainCfg, + Checkpoint: func(_ ContractResolver) error { + checkPointChan <- struct{}{} + return nil + }, + }, + htlcResolution: outgoingRes, + }, + } + + return &outgoingResolverTestContext{ + resolver: resolver, + notifier: notifier, + preimageDB: preimageDB, + resolutionChan: resolutionChan, + t: t, + } +} + +func (i *outgoingResolverTestContext) resolve() { + // Start resolver. + i.resolverResultChan = make(chan resolveResult, 1) + go func() { + nextResolver, err := i.resolver.Resolve() + i.resolverResultChan <- resolveResult{ + nextResolver: nextResolver, + err: err, + } + }() + + // Notify initial block height. + i.notifyEpoch(testInitialBlockHeight) +} + +func (i *outgoingResolverTestContext) notifyEpoch(height int32) { + i.notifier.epochChan <- &chainntnfs.BlockEpoch{ + Height: height, + } +} + +func (i *outgoingResolverTestContext) waitForResult(expectTimeoutRes bool) { + i.t.Helper() + + result := <-i.resolverResultChan + if result.err != nil { + i.t.Fatal(result.err) + } + + if !expectTimeoutRes { + if result.nextResolver != nil { + i.t.Fatal("expected no next resolver") + } + return + } + + _, ok := result.nextResolver.(*htlcTimeoutResolver) + if !ok { + i.t.Fatal("expected htlcTimeoutResolver") + } +}