diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 59f9bd3b9..0cccd5a6c 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -161,9 +161,9 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, t.Fatalf("expected %v, got %v", ogRes.broadcastHeight, diskRes.broadcastHeight) } - if ogRes.htlcIndex != diskRes.htlcIndex { - t.Fatalf("expected %v, got %v", ogRes.htlcIndex, - diskRes.htlcIndex) + if ogRes.htlc.HtlcIndex != diskRes.htlc.HtlcIndex { + t.Fatalf("expected %v, got %v", ogRes.htlc.HtlcIndex, + diskRes.htlc.HtlcIndex) } } @@ -184,9 +184,9 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, t.Fatalf("expected %v, got %v", ogRes.broadcastHeight, diskRes.broadcastHeight) } - if ogRes.payHash != diskRes.payHash { - t.Fatalf("expected %v, got %v", ogRes.payHash, - diskRes.payHash) + if ogRes.htlc.RHash != diskRes.htlc.RHash { + t.Fatalf("expected %v, got %v", ogRes.htlc.RHash, + diskRes.htlc.RHash) } } @@ -265,7 +265,9 @@ func TestContractInsertionRetrieval(t *testing.T) { outputIncubating: true, resolved: true, broadcastHeight: 102, - htlcIndex: 12, + htlc: channeldb.HTLC{ + HtlcIndex: 12, + }, } successResolver := htlcSuccessResolver{ htlcResolution: lnwallet.IncomingHtlcResolution{ @@ -278,8 +280,10 @@ func TestContractInsertionRetrieval(t *testing.T) { outputIncubating: true, resolved: true, broadcastHeight: 109, - payHash: testPreimage, - sweepTx: nil, + htlc: channeldb.HTLC{ + RHash: testPreimage, + }, + sweepTx: nil, } resolvers := []ContractResolver{ &timeoutResolver, @@ -395,7 +399,9 @@ func TestContractResolution(t *testing.T) { outputIncubating: true, resolved: true, broadcastHeight: 192, - htlcIndex: 9912, + htlc: channeldb.HTLC{ + HtlcIndex: 9912, + }, } // First, we'll insert the resolver into the database and ensure that @@ -454,7 +460,9 @@ func TestContractSwapping(t *testing.T) { outputIncubating: true, resolved: true, broadcastHeight: 102, - htlcIndex: 12, + htlc: channeldb.HTLC{ + HtlcIndex: 12, + }, } contestResolver := &htlcOutgoingContestResolver{ htlcTimeoutResolver: timeoutResolver, diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 098bc5f3e..3b25abc75 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -151,6 +151,10 @@ type ChainArbitratorConfig struct { // NotifyClosedChannel is a function closure that the ChainArbitrator // will use to notify the ChannelNotifier about a newly closed channel. NotifyClosedChannel func(wire.OutPoint) + + // OnionProcessor is used to decode onion payloads for on-chain + // resolution. + OnionProcessor OnionProcessor } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 66730ed11..e63a20461 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -501,9 +501,20 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet) error { "resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) for _, resolver := range unresolvedContracts { - if err := c.supplementResolver(resolver, htlcMap); err != nil { - return err + htlcResolver, ok := resolver.(htlcContractResolver) + if !ok { + continue } + + htlcPoint := htlcResolver.HtlcPoint() + htlc, ok := htlcMap[htlcPoint] + if !ok { + return fmt.Errorf( + "htlc resolver %T unavailable", resolver, + ) + } + + htlcResolver.Supplement(*htlc) } c.launchResolvers(unresolvedContracts) @@ -511,89 +522,6 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet) error { return nil } -// supplementResolver takes a resolver as it is restored from the log and fills -// in missing data from the htlcMap. -func (c *ChannelArbitrator) supplementResolver(resolver ContractResolver, - htlcMap map[wire.OutPoint]*channeldb.HTLC) error { - - switch r := resolver.(type) { - - case *htlcSuccessResolver: - return c.supplementSuccessResolver(r, htlcMap) - - case *htlcIncomingContestResolver: - return c.supplementIncomingContestResolver(r, htlcMap) - - case *htlcTimeoutResolver: - return c.supplementTimeoutResolver(r, htlcMap) - - case *htlcOutgoingContestResolver: - return c.supplementTimeoutResolver( - &r.htlcTimeoutResolver, htlcMap, - ) - } - - return nil -} - -// supplementSuccessResolver takes a htlcIncomingContestResolver as it is -// restored from the log and fills in missing data from the htlcMap. -func (c *ChannelArbitrator) supplementIncomingContestResolver( - r *htlcIncomingContestResolver, - htlcMap map[wire.OutPoint]*channeldb.HTLC) error { - - res := r.htlcResolution - htlcPoint := res.HtlcPoint() - htlc, ok := htlcMap[htlcPoint] - if !ok { - return errors.New( - "htlc for incoming contest resolver unavailable", - ) - } - - r.htlcAmt = htlc.Amt - r.circuitKey = channeldb.CircuitKey{ - ChanID: c.cfg.ShortChanID, - HtlcID: htlc.HtlcIndex, - } - - return nil -} - -// supplementSuccessResolver takes a htlcSuccessResolver as it is restored from -// the log and fills in missing data from the htlcMap. -func (c *ChannelArbitrator) supplementSuccessResolver(r *htlcSuccessResolver, - htlcMap map[wire.OutPoint]*channeldb.HTLC) error { - - res := r.htlcResolution - htlcPoint := res.HtlcPoint() - htlc, ok := htlcMap[htlcPoint] - if !ok { - return errors.New( - "htlc for success resolver unavailable", - ) - } - r.htlcAmt = htlc.Amt - return nil -} - -// supplementTimeoutResolver takes a htlcSuccessResolver as it is restored from -// the log and fills in missing data from the htlcMap. -func (c *ChannelArbitrator) supplementTimeoutResolver(r *htlcTimeoutResolver, - htlcMap map[wire.OutPoint]*channeldb.HTLC) error { - - res := r.htlcResolution - htlcPoint := res.HtlcPoint() - htlc, ok := htlcMap[htlcPoint] - if !ok { - return errors.New( - "htlc for timeout resolver unavailable", - ) - } - r.htlcAmt = htlc.Amt - return nil -} - // Report returns htlc reports for the active resolvers. func (c *ChannelArbitrator) Report() []*ContractReport { c.activeResolversLock.RLock() @@ -1224,8 +1152,10 @@ func (c *ChannelArbitrator) checkCommitChainActions(height uint32, // * race condition if adding and we broadcast, etc // * or would make each instance sync? - log.Debugf("ChannelArbitrator(%v): checking chain actions at "+ - "height=%v", c.cfg.ChanPoint, height) + log.Debugf("ChannelArbitrator(%v): checking commit chain actions at "+ + "height=%v, in_htlc_count=%v, out_htlc_count=%v", + c.cfg.ChanPoint, height, + len(htlcs.incomingHTLCs), len(htlcs.outgoingHTLCs)) actionMap := make(ChainActionMap) @@ -1719,6 +1649,8 @@ func (c *ChannelArbitrator) prepContractResolutions( // claim the HTLC (second-level or directly), then add the pre case HtlcClaimAction: for _, htlc := range htlcs { + htlc := htlc + htlcOp := wire.OutPoint{ Hash: commitHash, Index: uint32(htlc.OutputIndex), @@ -1734,8 +1666,7 @@ func (c *ChannelArbitrator) prepContractResolutions( } resolver := newSuccessResolver( - resolution, height, - htlc.RHash, htlc.Amt, resolverCfg, + resolution, height, htlc, resolverCfg, ) htlcResolvers = append(htlcResolvers, resolver) } @@ -1745,6 +1676,8 @@ func (c *ChannelArbitrator) prepContractResolutions( // backwards. case HtlcTimeoutAction: for _, htlc := range htlcs { + htlc := htlc + htlcOp := wire.OutPoint{ Hash: commitHash, Index: uint32(htlc.OutputIndex), @@ -1758,8 +1691,7 @@ func (c *ChannelArbitrator) prepContractResolutions( } resolver := newTimeoutResolver( - resolution, height, htlc.HtlcIndex, - htlc.Amt, resolverCfg, + resolution, height, htlc, resolverCfg, ) htlcResolvers = append(htlcResolvers, resolver) } @@ -1769,6 +1701,8 @@ func (c *ChannelArbitrator) prepContractResolutions( // learn of the pre-image, or let the remote party time out. case HtlcIncomingWatchAction: for _, htlc := range htlcs { + htlc := htlc + htlcOp := wire.OutPoint{ Hash: commitHash, Index: uint32(htlc.OutputIndex), @@ -1785,15 +1719,9 @@ func (c *ChannelArbitrator) prepContractResolutions( continue } - circuitKey := channeldb.CircuitKey{ - HtlcID: htlc.HtlcIndex, - ChanID: c.cfg.ShortChanID, - } - resolver := newIncomingContestResolver( - htlc.RefundTimeout, circuitKey, - resolution, height, htlc.RHash, - htlc.Amt, resolverCfg, + resolution, height, htlc, + resolverCfg, ) htlcResolvers = append(htlcResolvers, resolver) } @@ -1803,6 +1731,8 @@ func (c *ChannelArbitrator) prepContractResolutions( // backwards), or just timeout. case HtlcOutgoingWatchAction: for _, htlc := range htlcs { + htlc := htlc + htlcOp := wire.OutPoint{ Hash: commitHash, Index: uint32(htlc.OutputIndex), @@ -1817,8 +1747,7 @@ func (c *ChannelArbitrator) prepContractResolutions( } resolver := newOutgoingContestResolver( - resolution, height, htlc.HtlcIndex, - htlc.Amt, resolverCfg, + resolution, height, htlc, resolverCfg, ) htlcResolvers = append(htlcResolvers, resolver) } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 85c76479d..599028736 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -308,6 +308,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog) (*chanArbTestC incubateChan <- struct{}{} return nil }, + OnionProcessor: &mockOnionProcessor{}, } // We'll use the resolvedChan to synchronize on call to @@ -858,10 +859,10 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { resolver) } - // The resolver should have its htlcAmt field populated as it. - if int64(outgoingResolver.htlcAmt) != int64(htlcAmt) { + // The resolver should have its htlc amt field populated as it. + if int64(outgoingResolver.htlc.Amt) != int64(htlcAmt) { t.Fatalf("wrong htlc amount: expected %v, got %v,", - htlcAmt, int64(outgoingResolver.htlcAmt)) + htlcAmt, int64(outgoingResolver.htlc.Amt)) } // htlcOutgoingContestResolver is now active and waiting for the HTLC to diff --git a/contractcourt/contract_resolvers.go b/contractcourt/contract_resolvers.go index 7248bba05..2cda229a8 100644 --- a/contractcourt/contract_resolvers.go +++ b/contractcourt/contract_resolvers.go @@ -4,6 +4,9 @@ import ( "encoding/binary" "errors" "io" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" ) var ( @@ -51,6 +54,18 @@ type ContractResolver interface { Stop() } +// htlcContractResolver is the required interface for htlc resolvers. +type htlcContractResolver interface { + ContractResolver + + // HtlcPoint returns the htlc's outpoint on the commitment tx. + HtlcPoint() wire.OutPoint + + // Supplement adds additional information to the resolver that is + // required before Resolve() is called. + Supplement(htlc channeldb.HTLC) +} + // reportingContractResolver is a ContractResolver that also exposes a report on // the resolution state of the contract. type reportingContractResolver interface { diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 5a40369cb..2e089724f 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -1,16 +1,17 @@ package contractcourt import ( + "bytes" "encoding/binary" "errors" "io" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" - "github.com/lightningnetwork/lnd/lnwire" ) // htlcIncomingContestResolver is a ContractResolver that's able to resolve an @@ -27,28 +28,22 @@ type htlcIncomingContestResolver struct { // successfully. htlcExpiry uint32 - // circuitKey describes the incoming htlc that is being resolved. - circuitKey channeldb.CircuitKey - // htlcSuccessResolver is the inner resolver that may be utilized if we // learn of the preimage. htlcSuccessResolver } // newIncomingContestResolver instantiates a new incoming htlc contest resolver. -func newIncomingContestResolver(htlcExpiry uint32, - circuitKey channeldb.CircuitKey, res lnwallet.IncomingHtlcResolution, - broadcastHeight uint32, payHash lntypes.Hash, - htlcAmt lnwire.MilliSatoshi, - resCfg ResolverConfig) *htlcIncomingContestResolver { +func newIncomingContestResolver( + res lnwallet.IncomingHtlcResolution, broadcastHeight uint32, + htlc channeldb.HTLC, resCfg ResolverConfig) *htlcIncomingContestResolver { success := newSuccessResolver( - res, broadcastHeight, payHash, htlcAmt, resCfg, + res, broadcastHeight, htlc, resCfg, ) return &htlcIncomingContestResolver{ - htlcExpiry: htlcExpiry, - circuitKey: circuitKey, + htlcExpiry: htlc.RefundTimeout, htlcSuccessResolver: *success, } } @@ -72,6 +67,22 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { return nil, nil } + // First try to parse the payload. If that fails, we can stop resolution + // now. + payload, err := h.decodePayload() + if err != nil { + log.Debugf("ChannelArbitrator(%v): cannot decode payload of "+ + "htlc %v", h.ChanPoint, h.HtlcPoint()) + + // If we've locked in an htlc with an invalid payload on our + // commitment tx, we don't need to resolve it. The other party + // will time it out and get their funds back. This situation can + // present itself when we crash before processRemoteAdds in the + // link has ran. + h.resolved = true + return nil, nil + } + // Register for block epochs. After registration, the current height // will be sent on the channel immediately. blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil) @@ -119,7 +130,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { applyPreimage := func(preimage lntypes.Preimage) error { // Sanity check to see if this preimage matches our htlc. At // this point it should never happen that it does not match. - if !preimage.Matches(h.payHash) { + if !preimage.Matches(h.htlc.RHash) { return errors.New("preimage does not match hash") } @@ -185,9 +196,14 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // on-chain. If this HTLC indeed pays to an existing invoice, the // invoice registry will tell us what to do with the HTLC. This is // identical to HTLC resolution in the link. + circuitKey := channeldb.CircuitKey{ + ChanID: h.ShortChanID, + HtlcID: h.htlc.HtlcIndex, + } + event, err := h.Registry.NotifyExitHopHtlc( - h.payHash, h.htlcAmt, h.htlcExpiry, currentHeight, - h.circuitKey, hodlChan, nil, + h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, currentHeight, + circuitKey, hodlChan, payload, ) switch err { case channeldb.ErrInvoiceNotFound: @@ -204,7 +220,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // With the epochs and preimage subscriptions initialized, we'll query // to see if we already know the preimage. - preimage, ok := h.PreimageDB.LookupPreimage(h.payHash) + preimage, ok := h.PreimageDB.LookupPreimage(h.htlc.RHash) if ok { // If we do, then this means we can claim the HTLC! However, // we don't know how to ourselves, so we'll return our inner @@ -222,7 +238,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { case preimage := <-preimageSubscription.WitnessUpdates: // We receive all new preimages, so we need to ignore // all except the preimage we are waiting for. - if !preimage.Matches(h.payHash) { + if !preimage.Matches(h.htlc.RHash) { continue } @@ -268,7 +284,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { func (h *htlcIncomingContestResolver) report() *ContractReport { // No locking needed as these values are read-only. - finalAmt := h.htlcAmt.ToSatoshis() + finalAmt := h.htlc.Amt.ToSatoshis() if h.htlcResolution.SignedSuccessTx != nil { finalAmt = btcutil.Amount( h.htlcResolution.SignedSuccessTx.TxOut[0].Value, @@ -338,6 +354,28 @@ func newIncomingContestResolverFromReader(r io.Reader, resCfg ResolverConfig) ( return h, nil } +// Supplement adds additional information to the resolver that is required +// before Resolve() is called. +// +// NOTE: Part of the htlcContractResolver interface. +func (h *htlcIncomingContestResolver) Supplement(htlc channeldb.HTLC) { + h.htlc = htlc +} + +// decodePayload (re)decodes the hop payload of a received htlc. +func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload, error) { + + onionReader := bytes.NewReader(h.htlc.OnionBlob) + iterator, err := h.OnionProcessor.ReconstructHopIterator( + onionReader, h.htlc.RHash[:], + ) + if err != nil { + return nil, err + } + + return iterator.HopPayload() +} + // A compile time assertion to ensure htlcIncomingContestResolver meets the // ContractResolver interface. -var _ ContractResolver = (*htlcIncomingContestResolver)(nil) +var _ htlcContractResolver = (*htlcIncomingContestResolver)(nil) diff --git a/contractcourt/htlc_incoming_resolver_test.go b/contractcourt/htlc_incoming_resolver_test.go index 1f28f0a50..ab8ad6ec4 100644 --- a/contractcourt/htlc_incoming_resolver_test.go +++ b/contractcourt/htlc_incoming_resolver_test.go @@ -2,9 +2,12 @@ package contractcourt import ( "bytes" + "io" + "io/ioutil" "testing" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnwallet" @@ -21,6 +24,7 @@ var ( testResPreimage = lntypes.Preimage{1, 2, 3} testResHash = testResPreimage.Hash() testResCircuitKey = channeldb.CircuitKey{} + testOnionBlob = []byte{4, 5, 6} ) // TestHtlcIncomingResolverFwdPreimageKnown tests resolution of a forwarded htlc @@ -107,6 +111,12 @@ func TestHtlcIncomingResolverExitSettle(t *testing.T) { } ctx.waitForResult(true) + + if !bytes.Equal( + ctx.onionProcessor.offeredOnionBlob, testOnionBlob, + ) { + t.Fatal("unexpected onion blob") + } } // TestHtlcIncomingResolverExitCancel tests resolution of an exit hop htlc for @@ -168,14 +178,39 @@ func TestHtlcIncomingResolverExitCancelHodl(t *testing.T) { ctx.waitForResult(false) } +type mockHopIterator struct { + hop.Iterator +} + +func (h *mockHopIterator) HopPayload() (*hop.Payload, error) { + return nil, nil +} + +type mockOnionProcessor struct { + offeredOnionBlob []byte +} + +func (o *mockOnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( + hop.Iterator, error) { + + data, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + o.offeredOnionBlob = data + + return &mockHopIterator{}, nil +} + type incomingResolverTestContext struct { - registry *mockRegistry - witnessBeacon *mockWitnessBeacon - resolver *htlcIncomingContestResolver - notifier *mockNotifier - resolveErr chan error - nextResolver ContractResolver - t *testing.T + registry *mockRegistry + witnessBeacon *mockWitnessBeacon + resolver *htlcIncomingContestResolver + notifier *mockNotifier + onionProcessor *mockOnionProcessor + resolveErr chan error + nextResolver ContractResolver + t *testing.T } func newIncomingResolverTestContext(t *testing.T) *incomingResolverTestContext { @@ -189,13 +224,16 @@ func newIncomingResolverTestContext(t *testing.T) *incomingResolverTestContext { notifyChan: make(chan notifyExitHopData, 1), } + onionProcessor := &mockOnionProcessor{} + checkPointChan := make(chan struct{}, 1) chainCfg := ChannelArbitratorConfig{ ChainArbitratorConfig: ChainArbitratorConfig{ - Notifier: notifier, - PreimageDB: witnessBeacon, - Registry: registry, + Notifier: notifier, + PreimageDB: witnessBeacon, + Registry: registry, + OnionProcessor: onionProcessor, }, } @@ -210,17 +248,21 @@ func newIncomingResolverTestContext(t *testing.T) *incomingResolverTestContext { htlcSuccessResolver: htlcSuccessResolver{ contractResolverKit: *newContractResolverKit(cfg), htlcResolution: lnwallet.IncomingHtlcResolution{}, - payHash: testResHash, + htlc: channeldb.HTLC{ + RHash: testResHash, + OnionBlob: testOnionBlob, + }, }, htlcExpiry: testHtlcExpiry, } return &incomingResolverTestContext{ - registry: registry, - witnessBeacon: witnessBeacon, - resolver: resolver, - notifier: notifier, - t: t, + registry: registry, + witnessBeacon: witnessBeacon, + resolver: resolver, + notifier: notifier, + onionProcessor: onionProcessor, + t: t, } } diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 7388b45a8..93a4adfa9 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -5,8 +5,8 @@ import ( "io" "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwallet" - "github.com/lightningnetwork/lnd/lnwire" ) // htlcOutgoingContestResolver is a ContractResolver that's able to resolve an @@ -23,11 +23,11 @@ type htlcOutgoingContestResolver struct { // newOutgoingContestResolver instantiates a new outgoing contested htlc // resolver. func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, - broadcastHeight uint32, htlcIndex uint64, htlcAmt lnwire.MilliSatoshi, + broadcastHeight uint32, htlc channeldb.HTLC, resCfg ResolverConfig) *htlcOutgoingContestResolver { timeout := newTimeoutResolver( - res, broadcastHeight, htlcIndex, htlcAmt, resCfg, + res, broadcastHeight, htlc, resCfg, ) return &htlcOutgoingContestResolver{ @@ -157,7 +157,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { func (h *htlcOutgoingContestResolver) report() *ContractReport { // No locking needed as these values are read-only. - finalAmt := h.htlcAmt.ToSatoshis() + finalAmt := h.htlc.Amt.ToSatoshis() if h.htlcResolution.SignedTimeoutTx != nil { finalAmt = btcutil.Amount( h.htlcResolution.SignedTimeoutTx.TxOut[0].Value, @@ -215,4 +215,4 @@ func newOutgoingContestResolverFromReader(r io.Reader, resCfg ResolverConfig) ( // A compile time assertion to ensure htlcOutgoingContestResolver meets the // ContractResolver interface. -var _ ContractResolver = (*htlcOutgoingContestResolver)(nil) +var _ htlcContractResolver = (*htlcOutgoingContestResolver)(nil) diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index 1f7dbed13..3a6c1ea58 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" @@ -98,6 +99,8 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { preimageDB := newMockWitnessBeacon() + onionProcessor := &mockOnionProcessor{} + chainCfg := ChannelArbitratorConfig{ ChainArbitratorConfig: ChainArbitratorConfig{ Notifier: notifier, @@ -112,6 +115,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { resolutionChan <- msgs[0] return nil }, + OnionProcessor: onionProcessor, }, } @@ -134,6 +138,10 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { htlcTimeoutResolver: htlcTimeoutResolver{ contractResolverKit: *newContractResolverKit(cfg), htlcResolution: outgoingRes, + htlc: channeldb.HTLC{ + RHash: testResHash, + OnionBlob: testOnionBlob, + }, }, } diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 73bf5c5bc..73d370632 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -6,10 +6,9 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/sweep" ) @@ -38,9 +37,6 @@ type htlcSuccessResolver struct { // historical queries to the chain for spends/confirmations. broadcastHeight uint32 - // payHash is the payment hash of the original HTLC extended to us. - payHash lntypes.Hash - // sweepTx will be non-nil if we've already crafted a transaction to // sweep a direct HTLC output. This is only a concern if we're sweeping // from the commitment transaction of the remote party. @@ -48,25 +44,22 @@ type htlcSuccessResolver struct { // TODO(roasbeef): send off to utxobundler sweepTx *wire.MsgTx - // htlcAmt is the original amount of the htlc, not taking into - // account any fees that may have to be paid if it goes on chain. - htlcAmt lnwire.MilliSatoshi + // htlc contains information on the htlc that we are resolving on-chain. + htlc channeldb.HTLC contractResolverKit } // newSuccessResolver instanties a new htlc success resolver. func newSuccessResolver(res lnwallet.IncomingHtlcResolution, - broadcastHeight uint32, payHash lntypes.Hash, - htlcAmt lnwire.MilliSatoshi, + broadcastHeight uint32, htlc channeldb.HTLC, resCfg ResolverConfig) *htlcSuccessResolver { return &htlcSuccessResolver{ contractResolverKit: *newContractResolverKit(resCfg), htlcResolution: res, broadcastHeight: broadcastHeight, - payHash: payHash, - htlcAmt: htlcAmt, + htlc: htlc, } } @@ -114,7 +107,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { if h.sweepTx == nil { log.Infof("%T(%x): crafting sweep tx for "+ "incoming+remote htlc confirmed", h, - h.payHash[:]) + h.htlc.RHash[:]) // Before we can craft out sweeping transaction, we // need to create an input which contains all the items @@ -148,7 +141,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { } log.Infof("%T(%x): crafted sweep tx=%v", h, - h.payHash[:], spew.Sdump(h.sweepTx)) + h.htlc.RHash[:], spew.Sdump(h.sweepTx)) // With the sweep transaction signed, we'll now // Checkpoint our state. @@ -164,7 +157,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { err := h.PublishTx(h.sweepTx) if err != nil { log.Infof("%T(%x): unable to publish tx: %v", - h, h.payHash[:], err) + h, h.htlc.RHash[:], err) return nil, err } @@ -180,7 +173,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { } log.Infof("%T(%x): waiting for sweep tx (txid=%v) to be "+ - "confirmed", h, h.payHash[:], sweepTXID) + "confirmed", h, h.htlc.RHash[:], sweepTXID) select { case _, ok := <-confNtfn.Confirmed: @@ -199,7 +192,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { } log.Infof("%T(%x): broadcasting second-layer transition tx: %v", - h, h.payHash[:], spew.Sdump(h.htlcResolution.SignedSuccessTx)) + h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedSuccessTx)) // We'll now broadcast the second layer transaction so we can kick off // the claiming process. @@ -215,7 +208,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // done so. if !h.outputIncubating { log.Infof("%T(%x): incubating incoming htlc output", - h, h.payHash[:]) + h, h.htlc.RHash[:]) err := h.IncubateOutputs( h.ChanPoint, nil, nil, &h.htlcResolution, @@ -245,7 +238,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { } log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+ - "after csv_delay=%v", h, h.payHash[:], h.htlcResolution.CsvDelay) + "after csv_delay=%v", h, h.htlc.RHash[:], h.htlcResolution.CsvDelay) select { case _, ok := <-spendNtfn.Spend: @@ -298,7 +291,7 @@ func (h *htlcSuccessResolver) Encode(w io.Writer) error { if err := binary.Write(w, endian, h.broadcastHeight); err != nil { return err } - if _, err := w.Write(h.payHash[:]); err != nil { + if _, err := w.Write(h.htlc.RHash[:]); err != nil { return err } @@ -331,13 +324,28 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) ( if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { return nil, err } - if _, err := io.ReadFull(r, h.payHash[:]); err != nil { + if _, err := io.ReadFull(r, h.htlc.RHash[:]); err != nil { return nil, err } return h, nil } +// Supplement adds additional information to the resolver that is required +// before Resolve() is called. +// +// NOTE: Part of the htlcContractResolver interface. +func (h *htlcSuccessResolver) Supplement(htlc channeldb.HTLC) { + h.htlc = htlc +} + +// HtlcPoint returns the htlc's outpoint on the commitment tx. +// +// NOTE: Part of the htlcContractResolver interface. +func (h *htlcSuccessResolver) HtlcPoint() wire.OutPoint { + return h.htlcResolution.HtlcPoint() +} + // A compile time assertion to ensure htlcSuccessResolver meets the // ContractResolver interface. -var _ ContractResolver = (*htlcSuccessResolver)(nil) +var _ htlcContractResolver = (*htlcSuccessResolver)(nil) diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 248cfa36b..2469f5318 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" @@ -40,29 +41,22 @@ type htlcTimeoutResolver struct { // TODO(roasbeef): wrap above into definite resolution embedding? broadcastHeight uint32 - // htlcIndex is the index of this HTLC within the trace of the - // additional commitment state machine. - htlcIndex uint64 - - // htlcAmt is the original amount of the htlc, not taking into - // account any fees that may have to be paid if it goes on chain. - htlcAmt lnwire.MilliSatoshi + // htlc contains information on the htlc that we are resolving on-chain. + htlc channeldb.HTLC contractResolverKit } // newTimeoutResolver instantiates a new timeout htlc resolver. func newTimeoutResolver(res lnwallet.OutgoingHtlcResolution, - broadcastHeight uint32, htlcIndex uint64, - htlcAmt lnwire.MilliSatoshi, + broadcastHeight uint32, htlc channeldb.HTLC, resCfg ResolverConfig) *htlcTimeoutResolver { return &htlcTimeoutResolver{ contractResolverKit: *newContractResolverKit(resCfg), htlcResolution: res, broadcastHeight: broadcastHeight, - htlcIndex: htlcIndex, - htlcAmt: htlcAmt, + htlc: htlc, } } @@ -157,7 +151,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( // resolved, then exit. if err := h.DeliverResolutionMsg(ResolutionMsg{ SourceChan: h.ShortChanID, - HtlcIndex: h.htlcIndex, + HtlcIndex: h.htlc.HtlcIndex, PreImage: &pre, }); err != nil { return nil, err @@ -352,7 +346,7 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { failureMsg := &lnwire.FailPermanentChannelFailure{} if err := h.DeliverResolutionMsg(ResolutionMsg{ SourceChan: h.ShortChanID, - HtlcIndex: h.htlcIndex, + HtlcIndex: h.htlc.HtlcIndex, Failure: failureMsg, }); err != nil { return nil, err @@ -414,7 +408,7 @@ func (h *htlcTimeoutResolver) Encode(w io.Writer) error { return err } - if err := binary.Write(w, endian, h.htlcIndex); err != nil { + if err := binary.Write(w, endian, h.htlc.HtlcIndex); err != nil { return err } @@ -449,13 +443,28 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) ( return nil, err } - if err := binary.Read(r, endian, &h.htlcIndex); err != nil { + if err := binary.Read(r, endian, &h.htlc.HtlcIndex); err != nil { return nil, err } return h, nil } +// Supplement adds additional information to the resolver that is required +// before Resolve() is called. +// +// NOTE: Part of the htlcContractResolver interface. +func (h *htlcTimeoutResolver) Supplement(htlc channeldb.HTLC) { + h.htlc = htlc +} + +// HtlcPoint returns the htlc's outpoint on the commitment tx. +// +// NOTE: Part of the htlcContractResolver interface. +func (h *htlcTimeoutResolver) HtlcPoint() wire.OutPoint { + return h.htlcResolution.HtlcPoint() +} + // A compile time assertion to ensure htlcTimeoutResolver meets the // ContractResolver interface. -var _ ContractResolver = (*htlcTimeoutResolver)(nil) +var _ htlcContractResolver = (*htlcTimeoutResolver)(nil) diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index c73a23a03..a88477856 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -1,7 +1,10 @@ package contractcourt import ( + "io" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -26,3 +29,10 @@ type Registry interface { // HodlUnsubscribeAll unsubscribes from all hodl events. HodlUnsubscribeAll(subscriber chan<- interface{}) } + +// OnionProcessor is an interface used to decode onion blobs. +type OnionProcessor interface { + // ReconstructHopIterator attempts to decode a valid sphinx packet from + // the passed io.Reader instance. + ReconstructHopIterator(r io.Reader, rHash []byte) (hop.Iterator, error) +} diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index a11ce632c..5c8afed2b 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -184,6 +184,30 @@ func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte, return makeSphinxHopIterator(onionPkt, sphinxPacket), lnwire.CodeNone } +// ReconstructHopIterator attempts to decode a valid sphinx packet from the passed io.Reader +// instance using the rHash as the associated data when checking the relevant +// MACs during the decoding process. +func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( + Iterator, error) { + + onionPkt := &sphinx.OnionPacket{} + if err := onionPkt.Decode(r); err != nil { + return nil, err + } + + // Attempt to process the Sphinx packet. We include the payment hash of + // the HTLC as it's authenticated within the Sphinx packet itself as + // associated data in order to thwart attempts a replay attacks. In the + // case of a replay, an attacker is *forced* to use the same payment + // hash twice, thereby losing their money entirely. + sphinxPacket, err := p.router.ReconstructOnionPacket(onionPkt, rHash) + if err != nil { + return nil, err + } + + return makeSphinxHopIterator(onionPkt, sphinxPacket), nil +} + // DecodeHopIteratorRequest encapsulates all date necessary to process an onion // packet, perform sphinx replay detection, and schedule the entry for garbage // collection. diff --git a/server.go b/server.go index 7047721c1..ccdd13ae0 100644 --- a/server.go +++ b/server.go @@ -906,6 +906,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, Sweeper: s.sweeper, Registry: s.invoices, NotifyClosedChannel: s.channelNotifier.NotifyClosedChannelEvent, + OnionProcessor: s.sphinx, }, chanDB) s.breachArbiter = newBreachArbiter(&BreachConfig{