diff --git a/channeldb/channel.go b/channeldb/channel.go index f91c914d8..9d7d01bae 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -497,6 +497,7 @@ func (c *OpenChannel) RefreshShortChanID() error { } c.ShortChannelID = sid + c.Packager = NewChannelPackager(sid) return nil } @@ -665,6 +666,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { c.IsPending = false c.ShortChannelID = openLoc + c.Packager = NewChannelPackager(openLoc) return nil } @@ -1474,6 +1476,9 @@ func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { // processed, and returns their deserialized log updates in map indexed by the // remote commitment height at which the updates were locked in. func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + c.RLock() + defer c.RUnlock() + var fwdPkgs []*FwdPkg if err := c.Db.View(func(tx *bolt.Tx) error { var err error @@ -1489,6 +1494,9 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { // SetFwdFilter atomically sets the forwarding filter for the forwarding package // identified by `height`. func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + c.Lock() + defer c.Unlock() + return c.Db.Update(func(tx *bolt.Tx) error { return c.Packager.SetFwdFilter(tx, height, fwdFilter) }) @@ -1499,6 +1507,9 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { // // NOTE: This method should only be called on packages marked FwdStateCompleted. 
func (c *OpenChannel) RemoveFwdPkg(height uint64) error { + c.Lock() + defer c.Unlock() + return c.Db.Update(func(tx *bolt.Tx) error { return c.Packager.RemovePkg(tx, height) }) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index f7f331af7..efd36abbd 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -898,6 +898,16 @@ func TestRefreshShortChanID(t *testing.T) { "updated before refreshing short_chan_id") } + // Now that the receiver's short channel id has been updated, check to + // ensure that the channel packager's source has been updated as well. + // This ensures that the packager will read and write to buckets + // corresponding to the new short chan id, instead of the prior. + if state.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + state.Packager.(*ChannelPackager).source) + } + // Now, refresh the short channel ID of the pending channel. err = pendingChannel.RefreshShortChanID() if err != nil { @@ -911,4 +921,14 @@ func TestRefreshShortChanID(t *testing.T) { "refreshed: want %v, got %v", state.ShortChanID(), pendingChannel.ShortChanID()) } + + // Check to ensure that the _other_ OpenChannel channel packager's + // source has also been updated after the refresh. This ensures that the + // other packagers will read and write to buckets corresponding to the + // updated short chan id. 
+ if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + pendingChannel.Packager.(*ChannelPackager).source) + } } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 5fdb72c0e..9fc3023a5 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -301,6 +301,7 @@ func (c *ChainArbitrator) resolveContract(chanPoint wire.OutPoint, if ok { chainWatcher.Stop() } + delete(c.activeWatchers, chanPoint) c.Unlock() return nil diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 0ee683131..44e40ce34 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -231,6 +231,7 @@ func (c *ChannelArbitrator) Start() error { // machine can act accordingly. c.state, err = c.log.CurrentState() if err != nil { + c.cfg.BlockEpochs.Cancel() return err } @@ -239,6 +240,7 @@ func (c *ChannelArbitrator) Start() error { _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() if err != nil { + c.cfg.BlockEpochs.Cancel() return err } @@ -249,6 +251,7 @@ func (c *ChannelArbitrator) Start() error { uint32(bestHeight), chainTrigger, nil, ) if err != nil { + c.cfg.BlockEpochs.Cancel() return err } @@ -262,6 +265,7 @@ func (c *ChannelArbitrator) Start() error { // relaunch all contract resolvers. 
unresolvedContracts, err = c.log.FetchUnresolvedContracts() if err != nil { + c.cfg.BlockEpochs.Cancel() return err } @@ -301,8 +305,6 @@ func (c *ChannelArbitrator) Stop() error { close(c.quit) c.wg.Wait() - c.cfg.BlockEpochs.Cancel() - return nil } @@ -1289,7 +1291,10 @@ func (c *ChannelArbitrator) UpdateContractSignals(newSignals *ContractSignals) { func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // TODO(roasbeef): tell top chain arb we're done - defer c.wg.Done() + defer func() { + c.cfg.BlockEpochs.Cancel() + c.wg.Done() + }() for { select { diff --git a/docs/debugging_lnd.md b/docs/debugging_lnd.md new file mode 100644 index 000000000..3790b81e1 --- /dev/null +++ b/docs/debugging_lnd.md @@ -0,0 +1,47 @@ +# Table of Contents +1. [Overview](#overview) +1. [Debug Logging](#debug-logging) +1. [Capturing pprof data with `lnd`](#capturing-pprof-data-with-lnd) + +## Overview + +`lnd` ships with a few useful features for debugging, such as a built-in +profiler and tunable logging levels. If you need to submit a bug report +for `lnd`, it may be helpful to capture debug logging and performance +data ahead of time. + +## Debug Logging + +You can enable debug logging in `lnd` by passing the `--debuglevel` flag. For +example, to increase the log level from `info` to `debug`: + +``` +$ lnd --debuglevel=debug +``` + +You may also specify logging per-subsystem, like this: + +``` +$ lnd --debuglevel=<subsystem>=<level>,<subsystem2>=<level>,... +``` + +## Capturing pprof data with `lnd` + +`lnd` has a built-in feature which allows you to capture profiling data at +runtime using [pprof](https://golang.org/pkg/runtime/pprof/), a profiler for +Go. The profiler has negligible performance overhead during normal operations +(unless you have explicitly enabled CPU profiling). + +To enable this ability, start `lnd` with the `--profile` option using a free port. + +``` +$ lnd --profile=9736 +``` + +Now, with `lnd` running, you can use the pprof endpoint on port 9736 to collect +runtime profiling data.
You can fetch this data using `curl` like so: + +``` +$ curl http://localhost:9736/debug/pprof/goroutine?debug=1 +... +``` diff --git a/fundingmanager.go b/fundingmanager.go index df68f0fe5..b06464fbd 100644 --- a/fundingmanager.go +++ b/fundingmanager.go @@ -442,10 +442,11 @@ var ( // of being opened. channelOpeningStateBucket = []byte("channelOpeningState") - // ErrChannelNotFound is returned when we are looking for a specific - // channel opening state in the FundingManager's internal database, but - // the channel in question is not considered being in an opening state. - ErrChannelNotFound = fmt.Errorf("channel not found in db") + // ErrChannelNotFound is an error returned when a channel is not known + // to us. In this case of the fundingManager, this error is returned + // when the channel in question is not considered being in an opening + // state. + ErrChannelNotFound = fmt.Errorf("channel not found") ) // newFundingManager creates and initializes a new instance of the @@ -1616,9 +1617,7 @@ func (f *fundingManager) handleFundingSigned(fmsg *fundingSignedMsg) { fndgLog.Errorf("failed creating lnChannel: %v", err) return } - defer func() { - lnChannel.Stop() - }() + defer lnChannel.Stop() err = f.sendFundingLocked(completeChan, lnChannel, shortChanID) if err != nil { @@ -1879,9 +1878,7 @@ func (f *fundingManager) handleFundingConfirmation(completeChan *channeldb.OpenC if err != nil { return err } - defer func() { - lnChannel.Stop() - }() + defer lnChannel.Stop() chanID := lnwire.NewChanIDFromOutPoint(&completeChan.FundingOutpoint) @@ -2224,6 +2221,7 @@ func (f *fundingManager) handleFundingLocked(fmsg *fundingLockedMsg) { err = channel.InitNextRevocation(fmsg.msg.NextPerCommitmentPoint) if err != nil { fndgLog.Errorf("unable to insert next commitment point: %v", err) + channel.Stop() return } @@ -2249,6 +2247,7 @@ func (f *fundingManager) handleFundingLocked(fmsg *fundingLockedMsg) { peer, err := f.cfg.FindPeer(fmsg.peerAddress.IdentityKey) if err != nil { 
fndgLog.Errorf("Unable to find peer: %v", err) + channel.Stop() return } newChanDone := make(chan struct{}) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f7d3b5c88..1d649d732 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1883,7 +1883,7 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error { } } - link.Stop() + go link.Stop() return nil } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 0633e6b0e..f1311c620 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -651,9 +651,14 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{ - // TODO(roasbeef): later alternate encoding types - EncodingType: EncodingSortedPlain, + req := QueryShortChanIDs{} + + // With a 50/50 chance, we'll either use zlib encoding, + // or regular encoding. + if r.Int31()%2 == 0 { + req.EncodingType = EncodingSortedZlib + } else { + req.EncodingType = EncodingSortedPlain } if _, err := rand.Read(req.ChainHash[:]); err != nil { @@ -687,8 +692,13 @@ func TestLightningWireProtocol(t *testing.T) { req.Complete = uint8(r.Int31n(2)) - // TODO(roasbeef): later alternate encoding types - req.EncodingType = EncodingSortedPlain + // With a 50/50 chance, we'll either use zlib encoding, + // or regular encoding. + if r.Int31()%2 == 0 { + req.EncodingType = EncodingSortedZlib + } else { + req.EncodingType = EncodingSortedPlain + } numChanIDs := rand.Int31n(5000) diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 1f4c1d356..fd959c433 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -2,9 +2,11 @@ package lnwire import ( "bytes" + "compress/zlib" "fmt" "io" "sort" + "sync" "github.com/roasbeef/btcd/chaincfg/chainhash" ) @@ -20,10 +22,24 @@ const ( // encoded using the regular encoding, in a sorted order.
EncodingSortedPlain ShortChanIDEncoding = 0 - // TODO(roasbeef): list max number of short chan id's that are able to - // use + // EncodingSortedZlib signals that the set of short channel ID's is + // encoded by first sorting the set of channel ID's, and then + // compressing them using zlib. + EncodingSortedZlib ShortChanIDEncoding = 1 ) +const ( + // maxZlibBufSize is the max number of bytes that we'll accept from a + // zlib decoding instance. We do this in order to limit the total + // amount of memory allocated during a decoding instance. + maxZlibBufSize = 67413630 +) + +// zlibDecodeMtx is a package level mutex that we'll use in order to ensure +// that we'll only attempt a single zlib decoding instance at a time. This +// allows us to also further bound our memory usage. +var zlibDecodeMtx sync.Mutex + // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we // came across an unknown short channel ID encoding, and therefore were unable // to continue parsing. @@ -144,6 +160,71 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err return encodingType, shortChanIDs, nil + // In this encoding, we'll use zlib to decode the compressed payload. + // However, we'll pay attention to ensure that we don't open ourselves + // up to a memory exhaustion attack. + case EncodingSortedZlib: + // We'll obtain and ultimately release the zlib decode mutex. + // This guards us against allocating too much memory to decode + // each instance from concurrent peers. + zlibDecodeMtx.Lock() + defer zlibDecodeMtx.Unlock() + + // Before we start to decode, we'll create a limit reader over + // the current reader. This will ensure that we can control how + // much memory we're allocating during the decoding process.
+ limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ + R: bytes.NewReader(queryBody), + N: maxZlibBufSize, + }) + if err != nil { + return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) + } + + var ( + shortChanIDs []ShortChannelID + lastChanID ShortChannelID + ) + for { + // We'll now attempt to read the next short channel ID + // encoded in the payload. + var cid ShortChannelID + err := readElements(limitedDecompressor, &cid) + + switch { + // If we get an EOF error, then that either means we've + // read all that's contained in the buffer, or have hit + // our limit on the number of bytes we'll read. In + // either case, we'll return what we have so far. + case err == io.ErrUnexpectedEOF || err == io.EOF: + return encodingType, shortChanIDs, nil + + // Otherwise, we hit some other sort of error, possibly + // an invalid payload, so we'll exit early with the + // error. + case err != nil: + return 0, nil, fmt.Errorf("unable to "+ + "deflate next short chan "+ + "ID: %v", err) + } + + // We successfully read the next ID, so we'll collect + // that in the set of final ID's to return. + shortChanIDs = append(shortChanIDs, cid) + + // Finally, we'll ensure that this short chan ID is + // greater than the last one. This is a requirement + // within the encoding, and if violated can aid us in + // detecting malicious payloads.
+ if cid.ToUint64() <= lastChanID.ToUint64() { + return 0, nil, fmt.Errorf("current sid of %v "+ + "isn't greater than last sid of %v", cid, + lastChanID) + } + + lastChanID = cid + } + default: // If we've been sent an encoding type that we don't know of, // then we'll return a parsing error as we can't continue if @@ -173,6 +254,13 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, shortChanIDs []ShortChannelID) error { + // For both of the current encoding types, the channel ID's are to be + // sorted in place, so we'll do that now. + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + switch encodingType { // In this encoding, we'll simply write a sorted array of encoded short @@ -192,13 +280,6 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, return err } - // Next, we'll ensure that the set of short channel ID's is - // properly sorted in place. - sort.Slice(shortChanIDs, func(i, j int) bool { - return shortChanIDs[i].ToUint64() < - shortChanIDs[j].ToUint64() - }) - // Now that we know they're sorted, we can write out each short // channel ID to the buffer. for _, chanID := range shortChanIDs { @@ -210,6 +291,54 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, return nil + // For this encoding we'll first write out a serialized version of all + // the channel ID's into a buffer, then zlib encode that. The final + // payload is what we'll write out to the passed io.Writer. + // + // TODO(roasbeef): assumes the caller knows the proper chunk size to + // pass to avoid bin-packing here + case EncodingSortedZlib: + // We'll make a new buffer, then wrap that with a zlib writer + // so we can write directly to the buffer and encode in a + // streaming manner. 
+ var buf bytes.Buffer + zlibWriter := zlib.NewWriter(&buf) + + // Next, we'll write out all the channel ID's directly into the + // zlib writer, which will do compressing on the fly. + for _, chanID := range shortChanIDs { + err := writeElements(zlibWriter, chanID) + if err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + // Now that we've written all the elements, we'll ensure the + // compressed stream is written to the underlying buffer. + if err := zlibWriter.Close(); err != nil { + return fmt.Errorf("unable to finalize "+ + "compression: %v", err) + } + + // Now that we have all the items compressed, we can compute + // what the total payload size will be. We add one to account + // for the byte to encode the type. + compressedPayload := buf.Bytes() + numBytesBody := len(compressedPayload) + 1 + + // Finally, we can write out the number of bytes, the + // compression type, and finally the buffer itself. + if err := writeElements(w, uint16(numBytesBody)); err != nil { + return err + } + if err := writeElements(w, encodingType); err != nil { + return err + } + + _, err := w.Write(compressedPayload) + return err + default: // If we're trying to encode with an encoding type that we // don't know of, then we'll return a parsing error as we can't diff --git a/peer.go b/peer.go index 63e833637..82b077f67 100644 --- a/peer.go +++ b/peer.go @@ -314,7 +314,6 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error { p.server.cc.signer, p.server.witnessBeacon, dbChan, ) if err != nil { - lnChan.Stop() return err } @@ -1540,7 +1539,13 @@ out: // closure process. chanCloser, err := p.fetchActiveChanCloser(closeMsg.cid) if err != nil { - peerLog.Errorf("unable to respond to remote "+ + // If the channel is not known to us, we'll + // simply ignore this message. 
+ if err == ErrChannelNotFound { + continue + } + + peerLog.Errorf("Unable to respond to remote "+ "close msg: %v", err) errMsg := &lnwire.Error{ @@ -1618,8 +1623,7 @@ func (p *peer) fetchActiveChanCloser(chanID lnwire.ChannelID) (*channelCloser, e channel, ok := p.activeChannels[chanID] p.activeChanMtx.RUnlock() if !ok { - return nil, fmt.Errorf("unable to close channel, "+ - "ChannelID(%v) is unknown", chanID) + return nil, ErrChannelNotFound } // We'll attempt to look up the matching state machine, if we can't