mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-12 09:52:14 +02:00
Merge pull request #2474 from cfromknecht/read-and-write-pools
lnpeer+brontide: reduce memory footprint using read/write pools for message encode/decode
This commit is contained in:
commit
a6ba965bc4
@ -104,13 +104,34 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
|
|||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadNextMessage uses the connection in a message-oriented instructing it to
|
// ReadNextMessage uses the connection in a message-oriented manner, instructing
|
||||||
// read the next _full_ message with the brontide stream. This function will
|
// it to read the next _full_ message with the brontide stream. This function
|
||||||
// block until the read succeeds.
|
// will block until the read of the header and body succeeds.
|
||||||
|
//
|
||||||
|
// NOTE: This method SHOULD NOT be used in the case that the connection may be
|
||||||
|
// adversarial and induce long delays. If the caller needs to set read deadlines
|
||||||
|
// appropriately, it is preferred that they use the split ReadNextHeader and
|
||||||
|
// ReadNextBody methods so that the deadlines can be set appropriately on each.
|
||||||
func (c *Conn) ReadNextMessage() ([]byte, error) {
|
func (c *Conn) ReadNextMessage() ([]byte, error) {
|
||||||
return c.noise.ReadMessage(c.conn)
|
return c.noise.ReadMessage(c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadNextHeader uses the connection to read the next header from the brontide
|
||||||
|
// stream. This function will block until the read of the header succeeds and
|
||||||
|
// return the packet length (including MAC overhead) that is expected from the
|
||||||
|
// subsequent call to ReadNextBody.
|
||||||
|
func (c *Conn) ReadNextHeader() (uint32, error) {
|
||||||
|
return c.noise.ReadHeader(c.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadNextBody uses the connection to read the next message body from the
|
||||||
|
// brontide stream. This function will block until the read of the body succeeds
|
||||||
|
// and return the decrypted payload. The provided buffer MUST be the packet
|
||||||
|
// length returned by the preceding call to ReadNextHeader.
|
||||||
|
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
|
||||||
|
return c.noise.ReadBody(c.conn, buf)
|
||||||
|
}
|
||||||
|
|
||||||
// Read reads data from the connection. Read can be made to time out and
|
// Read reads data from the connection. Read can be made to time out and
|
||||||
// return an Error with Timeout() == true after a fixed time limit; see
|
// return an Error with Timeout() == true after a fixed time limit; see
|
||||||
// SetDeadline and SetReadDeadline.
|
// SetDeadline and SetReadDeadline.
|
||||||
|
@ -8,15 +8,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/hkdf"
|
"golang.org/x/crypto/hkdf"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/buffer"
|
|
||||||
"github.com/lightningnetwork/lnd/pool"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -60,14 +57,6 @@ var (
|
|||||||
ephemeralGen = func() (*btcec.PrivateKey, error) {
|
ephemeralGen = func() (*btcec.PrivateKey, error) {
|
||||||
return btcec.NewPrivateKey(btcec.S256())
|
return btcec.NewPrivateKey(btcec.S256())
|
||||||
}
|
}
|
||||||
|
|
||||||
// readBufferPool is a singleton instance of a buffer pool, used to
|
|
||||||
// conserve memory allocations due to read buffers across the entire
|
|
||||||
// brontide package.
|
|
||||||
readBufferPool = pool.NewReadBuffer(
|
|
||||||
pool.DefaultReadBufferGCInterval,
|
|
||||||
pool.DefaultReadBufferExpiryInterval,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO(roasbeef): free buffer pool?
|
// TODO(roasbeef): free buffer pool?
|
||||||
@ -378,15 +367,6 @@ type Machine struct {
|
|||||||
// next ciphertext header from the wire. The header is a 2 byte length
|
// next ciphertext header from the wire. The header is a 2 byte length
|
||||||
// (of the next ciphertext), followed by a 16 byte MAC.
|
// (of the next ciphertext), followed by a 16 byte MAC.
|
||||||
nextCipherHeader [lengthHeaderSize + macSize]byte
|
nextCipherHeader [lengthHeaderSize + macSize]byte
|
||||||
|
|
||||||
// nextCipherText is a static buffer that we'll use to read in the
|
|
||||||
// bytes of the next cipher text message. As all messages in the
|
|
||||||
// protocol MUST be below 65KB plus our macSize, this will be
|
|
||||||
// sufficient to buffer all messages from the socket when we need to
|
|
||||||
// read the next one. Having a fixed buffer that's re-used also means
|
|
||||||
// that we save on allocations as we don't need to create a new one
|
|
||||||
// each time.
|
|
||||||
nextCipherText *buffer.Read
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBrontideMachine creates a new instance of the brontide state-machine. If
|
// NewBrontideMachine creates a new instance of the brontide state-machine. If
|
||||||
@ -738,43 +718,59 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
|||||||
// ReadMessage attempts to read the next message from the passed io.Reader. In
|
// ReadMessage attempts to read the next message from the passed io.Reader. In
|
||||||
// the case of an authentication error, a non-nil error is returned.
|
// the case of an authentication error, a non-nil error is returned.
|
||||||
func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
|
func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
|
||||||
if _, err := io.ReadFull(r, b.nextCipherHeader[:]); err != nil {
|
pktLen, err := b.ReadHeader(r)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, pktLen)
|
||||||
|
return b.ReadBody(r, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadHeader attempts to read the next message header from the passed
|
||||||
|
// io.Reader. The header contains the length of the next body including
|
||||||
|
// additional overhead of the MAC. In the case of an authentication error, a
|
||||||
|
// non-nil error is returned.
|
||||||
|
//
|
||||||
|
// NOTE: This method SHOULD NOT be used in the case that the io.Reader may be
|
||||||
|
// adversarial and induce long delays. If the caller needs to set read deadlines
|
||||||
|
// appropriately, it is preferred that they use the split ReadHeader and
|
||||||
|
// ReadBody methods so that the deadlines can be set appropriately on each.
|
||||||
|
func (b *Machine) ReadHeader(r io.Reader) (uint32, error) {
|
||||||
|
_, err := io.ReadFull(r, b.nextCipherHeader[:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
// Attempt to decrypt+auth the packet length present in the stream.
|
// Attempt to decrypt+auth the packet length present in the stream.
|
||||||
pktLenBytes, err := b.recvCipher.Decrypt(
|
pktLenBytes, err := b.recvCipher.Decrypt(
|
||||||
nil, nil, b.nextCipherHeader[:],
|
nil, nil, b.nextCipherHeader[:],
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is the first message being read, take a read buffer from the
|
// Compute the packet length that we will need to read off the wire.
|
||||||
// buffer pool. This is delayed until this point to avoid allocating
|
|
||||||
// read buffers until after the peer has successfully completed the
|
|
||||||
// handshake, and is ready to begin sending lnwire messages.
|
|
||||||
if b.nextCipherText == nil {
|
|
||||||
b.nextCipherText = readBufferPool.Take()
|
|
||||||
runtime.SetFinalizer(b, freeReadBuffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, using the length read from the packet header, read the
|
|
||||||
// encrypted packet itself.
|
|
||||||
pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
|
pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
|
||||||
if _, err := io.ReadFull(r, b.nextCipherText[:pktLen]); err != nil {
|
|
||||||
|
return pktLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadBody attempts to ready the next message body from the passed io.Reader.
|
||||||
|
// The provided buffer MUST be the length indicated by the packet length
|
||||||
|
// returned by the preceding call to ReadHeader. In the case of an
|
||||||
|
// authentication eerror, a non-nil error is returned.
|
||||||
|
func (b *Machine) ReadBody(r io.Reader, buf []byte) ([]byte, error) {
|
||||||
|
// Next, using the length read from the packet header, read the
|
||||||
|
// encrypted packet itself into the buffer allocated by the read
|
||||||
|
// pool.
|
||||||
|
_, err := io.ReadFull(r, buf)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finally, decrypt the message held in the buffer, and return a
|
||||||
|
// new byte slice containing the plaintext.
|
||||||
// TODO(roasbeef): modify to let pass in slice
|
// TODO(roasbeef): modify to let pass in slice
|
||||||
return b.recvCipher.Decrypt(nil, nil, b.nextCipherText[:pktLen])
|
return b.recvCipher.Decrypt(nil, nil, buf)
|
||||||
}
|
|
||||||
|
|
||||||
// freeReadBuffer returns the Machine's read buffer back to the package wide
|
|
||||||
// read buffer pool.
|
|
||||||
//
|
|
||||||
// NOTE: This method should only be called by a Machine's finalizer.
|
|
||||||
func freeReadBuffer(b *Machine) {
|
|
||||||
readBufferPool.Return(b.nextCipherText)
|
|
||||||
b.nextCipherText = nil
|
|
||||||
}
|
}
|
||||||
|
105
peer.go
105
peer.go
@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/lnpeer"
|
"github.com/lightningnetwork/lnd/lnpeer"
|
||||||
"github.com/lightningnetwork/lnd/lnwallet"
|
"github.com/lightningnetwork/lnd/lnwallet"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/pool"
|
||||||
"github.com/lightningnetwork/lnd/ticker"
|
"github.com/lightningnetwork/lnd/ticker"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,8 +44,12 @@ const (
|
|||||||
// idleTimeout is the duration of inactivity before we time out a peer.
|
// idleTimeout is the duration of inactivity before we time out a peer.
|
||||||
idleTimeout = 5 * time.Minute
|
idleTimeout = 5 * time.Minute
|
||||||
|
|
||||||
// writeMessageTimeout is the timeout used when writing a message to peer.
|
// writeMessageTimeout is the timeout used when writing a message to a peer.
|
||||||
writeMessageTimeout = 50 * time.Second
|
writeMessageTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// readMessageTimeout is the timeout used when reading a message from a
|
||||||
|
// peer.
|
||||||
|
readMessageTimeout = 5 * time.Second
|
||||||
|
|
||||||
// handshakeTimeout is the timeout used when waiting for peer init message.
|
// handshakeTimeout is the timeout used when waiting for peer init message.
|
||||||
handshakeTimeout = 15 * time.Second
|
handshakeTimeout = 15 * time.Second
|
||||||
@ -209,11 +214,13 @@ type peer struct {
|
|||||||
// TODO(halseth): remove when link failure is properly handled.
|
// TODO(halseth): remove when link failure is properly handled.
|
||||||
failedChannels map[lnwire.ChannelID]struct{}
|
failedChannels map[lnwire.ChannelID]struct{}
|
||||||
|
|
||||||
// writeBuf is a buffer that we'll re-use in order to encode wire
|
// writePool is the task pool to that manages reuse of write buffers.
|
||||||
// messages to write out directly on the socket. By re-using this
|
// Write tasks are submitted to the pool in order to conserve the total
|
||||||
// buffer, we avoid needing to allocate more memory each time a new
|
// number of write buffers allocated at any one time, and decouple write
|
||||||
// message is to be sent to a peer.
|
// buffer allocation from the peer life cycle.
|
||||||
writeBuf *buffer.Write
|
writePool *pool.Write
|
||||||
|
|
||||||
|
readPool *pool.Read
|
||||||
|
|
||||||
queueQuit chan struct{}
|
queueQuit chan struct{}
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
@ -258,7 +265,8 @@ func newPeer(conn net.Conn, connReq *connmgr.ConnReq, server *server,
|
|||||||
|
|
||||||
chanActiveTimeout: chanActiveTimeout,
|
chanActiveTimeout: chanActiveTimeout,
|
||||||
|
|
||||||
writeBuf: server.writeBufferPool.Take(),
|
writePool: server.writePool,
|
||||||
|
readPool: server.readPool,
|
||||||
|
|
||||||
queueQuit: make(chan struct{}),
|
queueQuit: make(chan struct{}),
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
@ -608,11 +616,6 @@ func (p *peer) WaitForDisconnect(ready chan struct{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.wg.Wait()
|
p.wg.Wait()
|
||||||
|
|
||||||
// Now that we are certain all active goroutines which could have been
|
|
||||||
// modifying the write buffer have exited, return the buffer to the pool
|
|
||||||
// to be reused.
|
|
||||||
p.server.writeBufferPool.Return(p.writeBuf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect terminates the connection with the remote peer. Additionally, a
|
// Disconnect terminates the connection with the remote peer. Additionally, a
|
||||||
@ -644,11 +647,37 @@ func (p *peer) readNextMessage() (lnwire.Message, error) {
|
|||||||
return nil, fmt.Errorf("brontide.Conn required to read messages")
|
return nil, fmt.Errorf("brontide.Conn required to read messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err := noiseConn.SetReadDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pktLen, err := noiseConn.ReadNextHeader()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// First we'll read the next _full_ message. We do this rather than
|
// First we'll read the next _full_ message. We do this rather than
|
||||||
// reading incrementally from the stream as the Lightning wire protocol
|
// reading incrementally from the stream as the Lightning wire protocol
|
||||||
// is message oriented and allows nodes to pad on additional data to
|
// is message oriented and allows nodes to pad on additional data to
|
||||||
// the message stream.
|
// the message stream.
|
||||||
rawMsg, err := noiseConn.ReadNextMessage()
|
var rawMsg []byte
|
||||||
|
err = p.readPool.Submit(func(buf *buffer.Read) error {
|
||||||
|
// Before reading the body of the message, set the read timeout
|
||||||
|
// accordingly to ensure we don't block other readers using the
|
||||||
|
// pool. We do so only after the task has been scheduled to
|
||||||
|
// ensure the deadline doesn't expire while the message is in
|
||||||
|
// the process of being scheduled.
|
||||||
|
readDeadline := time.Now().Add(readMessageTimeout)
|
||||||
|
readErr := noiseConn.SetReadDeadline(readDeadline)
|
||||||
|
if readErr != nil {
|
||||||
|
return readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen])
|
||||||
|
return readErr
|
||||||
|
})
|
||||||
|
|
||||||
atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg)))
|
atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -1359,33 +1388,33 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
|
|||||||
|
|
||||||
p.logWireMessage(msg, false)
|
p.logWireMessage(msg, false)
|
||||||
|
|
||||||
// We'll re-slice of static write buffer to allow this new message to
|
var n int
|
||||||
// utilize all available space. We also ensure we cap the capacity of
|
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
|
||||||
// this new buffer to the static buffer which is sized for the largest
|
// Using a buffer allocated by the write pool, encode the
|
||||||
// possible protocol message.
|
// message directly into the buffer.
|
||||||
b := bytes.NewBuffer(p.writeBuf[0:0:len(p.writeBuf)])
|
_, writeErr := lnwire.WriteMessage(buf, msg, 0)
|
||||||
|
if writeErr != nil {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
|
||||||
// With the temp buffer created and sliced properly (length zero, full
|
// Ensure the write deadline is set before we attempt to send
|
||||||
// capacity), we'll now encode the message directly into this buffer.
|
// the message.
|
||||||
_, err := lnwire.WriteMessage(b, msg, 0)
|
writeDeadline := time.Now().Add(writeMessageTimeout)
|
||||||
if err != nil {
|
writeErr = p.conn.SetWriteDeadline(writeDeadline)
|
||||||
return err
|
if writeErr != nil {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, write the message itself in a single swoop.
|
||||||
|
n, writeErr = p.conn.Write(buf.Bytes())
|
||||||
|
return writeErr
|
||||||
|
})
|
||||||
|
|
||||||
|
// Record the number of bytes written on the wire, if any.
|
||||||
|
if n > 0 {
|
||||||
|
atomic.AddUint64(&p.bytesSent, uint64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute and set the write deadline we will impose on the remote peer.
|
|
||||||
writeDeadline := time.Now().Add(writeMessageTimeout)
|
|
||||||
err = p.conn.SetWriteDeadline(writeDeadline)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, write the message itself in a single swoop.
|
|
||||||
n, err := p.conn.Write(b.Bytes())
|
|
||||||
|
|
||||||
// Regardless of the error returned, record how many bytes were written
|
|
||||||
// to the wire.
|
|
||||||
atomic.AddUint64(&p.bytesSent, uint64(n))
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
87
pool/read.go
Normal file
87
pool/read.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/buffer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read is a worker pool specifically designed for sharing access to buffer.Read
|
||||||
|
// objects amongst a set of worker goroutines. This enables an application to
|
||||||
|
// limit the total number of buffer.Read objects allocated at any given time.
|
||||||
|
type Read struct {
|
||||||
|
workerPool *Worker
|
||||||
|
bufferPool *ReadBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRead creates a new Read pool, using an underlying ReadBuffer pool to
|
||||||
|
// recycle buffer.Read objects across the lifetime of the Read pool's workers.
|
||||||
|
func NewRead(readBufferPool *ReadBuffer, numWorkers int,
|
||||||
|
workerTimeout time.Duration) *Read {
|
||||||
|
|
||||||
|
r := &Read{
|
||||||
|
bufferPool: readBufferPool,
|
||||||
|
}
|
||||||
|
r.workerPool = NewWorker(&WorkerConfig{
|
||||||
|
NewWorkerState: r.newWorkerState,
|
||||||
|
NumWorkers: numWorkers,
|
||||||
|
WorkerTimeout: workerTimeout,
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start safely spins up the Read pool.
|
||||||
|
func (r *Read) Start() error {
|
||||||
|
return r.workerPool.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop safely shuts down the Read pool.
|
||||||
|
func (r *Read) Stop() error {
|
||||||
|
return r.workerPool.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit accepts a function closure that provides access to the fresh
|
||||||
|
// buffer.Read object. The function's execution will be allocated to one of the
|
||||||
|
// underlying Worker pool's goroutines.
|
||||||
|
func (r *Read) Submit(inner func(*buffer.Read) error) error {
|
||||||
|
return r.workerPool.Submit(func(s WorkerState) error {
|
||||||
|
state := s.(*readWorkerState)
|
||||||
|
return inner(state.readBuf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// readWorkerState is the per-goroutine state maintained by a Read pool's
|
||||||
|
// goroutines.
|
||||||
|
type readWorkerState struct {
|
||||||
|
// bufferPool is the pool to which the readBuf will be returned when the
|
||||||
|
// goroutine exits.
|
||||||
|
bufferPool *ReadBuffer
|
||||||
|
|
||||||
|
// readBuf is a buffer taken from the bufferPool on initialization,
|
||||||
|
// which will be cleaned and provided to any tasks that the goroutine
|
||||||
|
// processes before exiting.
|
||||||
|
readBuf *buffer.Read
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWorkerState initializes a new readWorkerState, which will be called
|
||||||
|
// whenever a new goroutine is allocated to begin processing read tasks.
|
||||||
|
func (r *Read) newWorkerState() WorkerState {
|
||||||
|
return &readWorkerState{
|
||||||
|
bufferPool: r.bufferPool,
|
||||||
|
readBuf: r.bufferPool.Take(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup returns the readBuf to the underlying buffer pool, and removes the
|
||||||
|
// goroutine's reference to the readBuf.
|
||||||
|
func (r *readWorkerState) Cleanup() {
|
||||||
|
r.bufferPool.Return(r.readBuf)
|
||||||
|
r.readBuf = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset recycles the readBuf to make it ready for any subsequent tasks the
|
||||||
|
// goroutine may process.
|
||||||
|
func (r *readWorkerState) Reset() {
|
||||||
|
r.readBuf.Recycle()
|
||||||
|
}
|
250
pool/worker.go
Normal file
250
pool/worker.go
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrWorkerPoolExiting signals that a shutdown of the Worker has been
|
||||||
|
// requested.
|
||||||
|
var ErrWorkerPoolExiting = errors.New("worker pool exiting")
|
||||||
|
|
||||||
|
// DefaultWorkerTimeout is the default duration after which a worker goroutine
|
||||||
|
// will exit to free up resources after having received no newly submitted
|
||||||
|
// tasks.
|
||||||
|
const DefaultWorkerTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
type (
|
||||||
|
// WorkerState is an interface used by the Worker to abstract the
|
||||||
|
// lifecycle of internal state used by a worker goroutine.
|
||||||
|
WorkerState interface {
|
||||||
|
// Reset clears any internal state that may have been dirtied in
|
||||||
|
// processing a prior task.
|
||||||
|
Reset()
|
||||||
|
|
||||||
|
// Cleanup releases any shared state before a worker goroutine
|
||||||
|
// exits.
|
||||||
|
Cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkerConfig parameterizes the behavior of a Worker pool.
|
||||||
|
WorkerConfig struct {
|
||||||
|
// NewWorkerState allocates a new state for a worker goroutine.
|
||||||
|
// This method is called each time a new worker goroutine is
|
||||||
|
// spawned by the pool.
|
||||||
|
NewWorkerState func() WorkerState
|
||||||
|
|
||||||
|
// NumWorkers is the maximum number of workers the Worker pool
|
||||||
|
// will permit to be allocated. Once the maximum number is
|
||||||
|
// reached, any newly submitted tasks are forced to be processed
|
||||||
|
// by existing worker goroutines.
|
||||||
|
NumWorkers int
|
||||||
|
|
||||||
|
// WorkerTimeout is the duration after which a worker goroutine
|
||||||
|
// will exit after having received no newly submitted tasks.
|
||||||
|
WorkerTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Worker maintains a pool of goroutines that process submitted function
|
||||||
|
// closures, and enable more efficient reuse of expensive state.
|
||||||
|
Worker struct {
|
||||||
|
started sync.Once
|
||||||
|
stopped sync.Once
|
||||||
|
|
||||||
|
cfg *WorkerConfig
|
||||||
|
|
||||||
|
// requests is a channel where new tasks are submitted. Tasks
|
||||||
|
// submitted through this channel may cause a new worker
|
||||||
|
// goroutine to be allocated.
|
||||||
|
requests chan *request
|
||||||
|
|
||||||
|
// work is a channel where new tasks are submitted, but is only
|
||||||
|
// read by active worker gorotuines.
|
||||||
|
work chan *request
|
||||||
|
|
||||||
|
// workerSem is a channel-based sempahore that is used to limit
|
||||||
|
// the total number of worker goroutines to the number
|
||||||
|
// prescribed by the WorkerConfig.
|
||||||
|
workerSem chan struct{}
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// request is a tuple of task closure and error channel that is used to
|
||||||
|
// both submit a task to the pool and respond with any errors
|
||||||
|
// encountered during the task's execution.
|
||||||
|
request struct {
|
||||||
|
fn func(WorkerState) error
|
||||||
|
errChan chan error
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWorker initializes a new Worker pool using the provided WorkerConfig.
|
||||||
|
func NewWorker(cfg *WorkerConfig) *Worker {
|
||||||
|
return &Worker{
|
||||||
|
cfg: cfg,
|
||||||
|
requests: make(chan *request),
|
||||||
|
workerSem: make(chan struct{}, cfg.NumWorkers),
|
||||||
|
work: make(chan *request),
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start safely spins up the Worker pool.
|
||||||
|
func (w *Worker) Start() error {
|
||||||
|
w.started.Do(func() {
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.requestHandler()
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop safely shuts down the Worker pool.
|
||||||
|
func (w *Worker) Stop() error {
|
||||||
|
w.stopped.Do(func() {
|
||||||
|
close(w.quit)
|
||||||
|
w.wg.Wait()
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit accepts a function closure to the worker pool. The returned error will
|
||||||
|
// be either the result of the closure's execution or an ErrWorkerPoolExiting if
|
||||||
|
// a shutdown is requested.
|
||||||
|
func (w *Worker) Submit(fn func(WorkerState) error) error {
|
||||||
|
req := &request{
|
||||||
|
fn: fn,
|
||||||
|
errChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
|
||||||
|
// Send request to requestHandler, where either a new worker is spawned
|
||||||
|
// or the task will be handed to an existing worker.
|
||||||
|
case w.requests <- req:
|
||||||
|
|
||||||
|
// Fast path directly to existing worker.
|
||||||
|
case w.work <- req:
|
||||||
|
|
||||||
|
case <-w.quit:
|
||||||
|
return ErrWorkerPoolExiting
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
|
||||||
|
// Wait for task to be processed.
|
||||||
|
case err := <-req.errChan:
|
||||||
|
return err
|
||||||
|
|
||||||
|
case <-w.quit:
|
||||||
|
return ErrWorkerPoolExiting
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestHandler processes incoming tasks by either allocating new worker
|
||||||
|
// goroutines to process the incoming tasks, or by feeding a submitted task to
|
||||||
|
// an already running worker goroutine.
|
||||||
|
func (w *Worker) requestHandler() {
|
||||||
|
defer w.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case req := <-w.requests:
|
||||||
|
select {
|
||||||
|
|
||||||
|
// If we have not reached our maximum number of workers,
|
||||||
|
// spawn one to process the submitted request.
|
||||||
|
case w.workerSem <- struct{}{}:
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.spawnWorker(req)
|
||||||
|
|
||||||
|
// Otherwise, submit the task to any of the active
|
||||||
|
// workers.
|
||||||
|
case w.work <- req:
|
||||||
|
|
||||||
|
case <-w.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-w.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnWorker is used when the Worker pool wishes to create a new worker
|
||||||
|
// goroutine. The worker's state is initialized by calling the config's
|
||||||
|
// NewWorkerState method, and will continue to process incoming tasks until the
|
||||||
|
// pool is shut down or no new tasks are received before the worker's timeout
|
||||||
|
// elapses.
|
||||||
|
//
|
||||||
|
// NOTE: This method MUST be run as a goroutine.
|
||||||
|
func (w *Worker) spawnWorker(req *request) {
|
||||||
|
defer w.wg.Done()
|
||||||
|
defer func() { <-w.workerSem }()
|
||||||
|
|
||||||
|
state := w.cfg.NewWorkerState()
|
||||||
|
defer state.Cleanup()
|
||||||
|
|
||||||
|
req.errChan <- req.fn(state)
|
||||||
|
|
||||||
|
// We'll use a timer to implement the worker timeouts, as this reduces
|
||||||
|
// the number of total allocations that would otherwise be necessary
|
||||||
|
// with time.After.
|
||||||
|
var t *time.Timer
|
||||||
|
for {
|
||||||
|
// Before processing another request, we'll reset the worker
|
||||||
|
// state to that each request is processed against a clean
|
||||||
|
// state.
|
||||||
|
state.Reset()
|
||||||
|
|
||||||
|
select {
|
||||||
|
|
||||||
|
// Process any new requests that get submitted. We use a
|
||||||
|
// non-blocking case first so that under high load we can spare
|
||||||
|
// allocating a timeout.
|
||||||
|
case req := <-w.work:
|
||||||
|
req.errChan <- req.fn(state)
|
||||||
|
continue
|
||||||
|
|
||||||
|
case <-w.quit:
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// There were no new requests that could be taken immediately
|
||||||
|
// from the work channel. Initialize or reset the timeout, which
|
||||||
|
// will fire if the worker doesn't receive a new task before
|
||||||
|
// needing to exit.
|
||||||
|
if t != nil {
|
||||||
|
t.Reset(w.cfg.WorkerTimeout)
|
||||||
|
} else {
|
||||||
|
t = time.NewTimer(w.cfg.WorkerTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
|
||||||
|
// Process any new requests that get submitted.
|
||||||
|
case req := <-w.work:
|
||||||
|
req.errChan <- req.fn(state)
|
||||||
|
|
||||||
|
// Stop the timer, draining the timer's channel if a
|
||||||
|
// notification was already delivered.
|
||||||
|
if !t.Stop() {
|
||||||
|
<-t.C
|
||||||
|
}
|
||||||
|
|
||||||
|
// The timeout has elapsed, meaning the worker did not receive
|
||||||
|
// any new tasks. Exit to allow the worker to return and free
|
||||||
|
// its resources.
|
||||||
|
case <-t.C:
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-w.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
353
pool/worker_test.go
Normal file
353
pool/worker_test.go
Normal file
@ -0,0 +1,353 @@
|
|||||||
|
package pool_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/buffer"
|
||||||
|
"github.com/lightningnetwork/lnd/pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type workerPoolTest struct {
|
||||||
|
name string
|
||||||
|
newPool func() interface{}
|
||||||
|
numWorkers int
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcreteWorkerPools asserts the behavior of any concrete implementations
|
||||||
|
// of worker pools provided by the pool package. Currently this tests the
|
||||||
|
// pool.Read and pool.Write instances.
|
||||||
|
func TestConcreteWorkerPools(t *testing.T) {
|
||||||
|
const (
|
||||||
|
gcInterval = time.Second
|
||||||
|
expiryInterval = 250 * time.Millisecond
|
||||||
|
numWorkers = 5
|
||||||
|
workerTimeout = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
tests := []workerPoolTest{
|
||||||
|
{
|
||||||
|
name: "write pool",
|
||||||
|
newPool: func() interface{} {
|
||||||
|
bp := pool.NewWriteBuffer(
|
||||||
|
gcInterval, expiryInterval,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pool.NewWrite(
|
||||||
|
bp, numWorkers, workerTimeout,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
numWorkers: numWorkers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "read pool",
|
||||||
|
newPool: func() interface{} {
|
||||||
|
bp := pool.NewReadBuffer(
|
||||||
|
gcInterval, expiryInterval,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pool.NewRead(
|
||||||
|
bp, numWorkers, workerTimeout,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
numWorkers: numWorkers,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
testWorkerPool(t, test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testWorkerPool(t *testing.T, test workerPoolTest) {
|
||||||
|
t.Run(test.name+" non blocking", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
p := test.newPool()
|
||||||
|
startGeneric(t, p)
|
||||||
|
defer stopGeneric(t, p)
|
||||||
|
|
||||||
|
submitNonblockingGeneric(t, p, test.numWorkers)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run(test.name+" blocking", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
p := test.newPool()
|
||||||
|
startGeneric(t, p)
|
||||||
|
defer stopGeneric(t, p)
|
||||||
|
|
||||||
|
submitBlockingGeneric(t, p, test.numWorkers)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run(test.name+" partial blocking", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
p := test.newPool()
|
||||||
|
startGeneric(t, p)
|
||||||
|
defer stopGeneric(t, p)
|
||||||
|
|
||||||
|
submitPartialBlockingGeneric(t, p, test.numWorkers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// submitNonblockingGeneric asserts that queueing tasks to the worker pool and
|
||||||
|
// allowing them all to unblock simultaneously results in all of the tasks being
|
||||||
|
// completed in a timely manner.
|
||||||
|
func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) {
|
||||||
|
// We'll submit 2*nWorkers tasks that will all be unblocked
|
||||||
|
// simultaneously.
|
||||||
|
nUnblocked := 2 * nWorkers
|
||||||
|
|
||||||
|
// First we'll queue all of the tasks for the pool.
|
||||||
|
errChan := make(chan error)
|
||||||
|
semChan := make(chan struct{})
|
||||||
|
for i := 0; i < nUnblocked; i++ {
|
||||||
|
go func() { errChan <- submitGeneric(p, semChan) }()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since we haven't signaled the semaphore, none of the them should
|
||||||
|
// complete.
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
|
||||||
|
// Now, unblock them all simultaneously. All of the tasks should then be
|
||||||
|
// processed in parallel. Afterward, no more errors should come through.
|
||||||
|
close(semChan)
|
||||||
|
pullParllel(t, nUnblocked, errChan)
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// submitBlockingGeneric asserts that submitting blocking tasks to the pool and
|
||||||
|
// unblocking each sequentially results in a single task being processed at a
|
||||||
|
// time.
|
||||||
|
func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
|
||||||
|
// We'll submit 2*nWorkers tasks that will be unblocked sequentially.
|
||||||
|
nBlocked := 2 * nWorkers
|
||||||
|
|
||||||
|
// First, queue all of the blocking tasks for the pool.
|
||||||
|
errChan := make(chan error)
|
||||||
|
semChan := make(chan struct{})
|
||||||
|
for i := 0; i < nBlocked; i++ {
|
||||||
|
go func() { errChan <- submitGeneric(p, semChan) }()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since we haven't signaled the semaphore, none of them should
|
||||||
|
// complete.
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
|
||||||
|
// Now, pull each blocking task sequentially from the pool. Afterwards,
|
||||||
|
// no more errors should come through.
|
||||||
|
pullSequntial(t, nBlocked, errChan, semChan)
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// submitPartialBlockingGeneric tests that so long as one worker is not blocked,
|
||||||
|
// any other non-blocking submitted tasks can still be processed.
|
||||||
|
func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
|
||||||
|
// We'll submit nWorkers-1 tasks that will be initially blocked, the
|
||||||
|
// remainder will all be unblocked simultaneously. After the unblocked
|
||||||
|
// tasks have finished, we will sequentially unblock the nWorkers-1
|
||||||
|
// tasks that were first submitted.
|
||||||
|
nBlocked := nWorkers - 1
|
||||||
|
nUnblocked := 2*nWorkers - nBlocked
|
||||||
|
|
||||||
|
// First, submit all of the blocking tasks to the pool.
|
||||||
|
errChan := make(chan error)
|
||||||
|
semChan := make(chan struct{})
|
||||||
|
for i := 0; i < nBlocked; i++ {
|
||||||
|
go func() { errChan <- submitGeneric(p, semChan) }()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since these are all blocked, no errors should be returned yet.
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
|
||||||
|
// Now, add all of the non-blocking task to the pool.
|
||||||
|
semChanNB := make(chan struct{})
|
||||||
|
for i := 0; i < nUnblocked; i++ {
|
||||||
|
go func() { errChan <- submitGeneric(p, semChanNB) }()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since we haven't unblocked the second batch, we again expect no tasks
|
||||||
|
// to finish.
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
|
||||||
|
// Now, unblock the unblocked task and pull all of them. After they have
|
||||||
|
// been pulled, we should see no more tasks.
|
||||||
|
close(semChanNB)
|
||||||
|
pullParllel(t, nUnblocked, errChan)
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
|
||||||
|
// Finally, unblock each the blocked tasks we added initially, and
|
||||||
|
// assert that no further errors come through.
|
||||||
|
pullSequntial(t, nBlocked, errChan, semChan)
|
||||||
|
pullNothing(t, errChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pullNothing(t *testing.T, errChan chan error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
t.Fatalf("received unexpected error before semaphore "+
|
||||||
|
"release: %v", err)
|
||||||
|
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pullParllel(t *testing.T, n int, errChan chan error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("task %d was not processed in time", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pullSequntial(t *testing.T, n int, errChan chan error, semChan chan struct{}) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
// Signal for another task to unblock.
|
||||||
|
select {
|
||||||
|
case semChan <- struct{}{}:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("task %d was not unblocked", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the error to arrive, we expect it to be non-nil.
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("task %d was not processed in time", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startGeneric(t *testing.T, p interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch pp := p.(type) {
|
||||||
|
case *pool.Write:
|
||||||
|
err = pp.Start()
|
||||||
|
|
||||||
|
case *pool.Read:
|
||||||
|
err = pp.Start()
|
||||||
|
|
||||||
|
default:
|
||||||
|
t.Fatalf("unknown worker pool type: %T", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to start worker pool: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopGeneric(t *testing.T, p interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch pp := p.(type) {
|
||||||
|
case *pool.Write:
|
||||||
|
err = pp.Stop()
|
||||||
|
|
||||||
|
case *pool.Read:
|
||||||
|
err = pp.Stop()
|
||||||
|
|
||||||
|
default:
|
||||||
|
t.Fatalf("unknown worker pool type: %T", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to stop worker pool: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func submitGeneric(p interface{}, sem <-chan struct{}) error {
|
||||||
|
var err error
|
||||||
|
switch pp := p.(type) {
|
||||||
|
case *pool.Write:
|
||||||
|
err = pp.Submit(func(buf *bytes.Buffer) error {
|
||||||
|
// Verify that the provided buffer has been reset to be
|
||||||
|
// zero length.
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
return fmt.Errorf("buf should be length zero, "+
|
||||||
|
"instead has length %d", buf.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the capacity of the buffer has the
|
||||||
|
// correct underlying size of a buffer.WriteSize.
|
||||||
|
if buf.Cap() != buffer.WriteSize {
|
||||||
|
return fmt.Errorf("buf should have capacity "+
|
||||||
|
"%d, instead has capacity %d",
|
||||||
|
buffer.WriteSize, buf.Cap())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sample some random bytes that we'll use to dirty the
|
||||||
|
// buffer.
|
||||||
|
b := make([]byte, rand.Intn(buf.Cap()))
|
||||||
|
_, err := io.ReadFull(crand.Reader, b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the random bytes the buffer.
|
||||||
|
_, err = buf.Write(b)
|
||||||
|
|
||||||
|
// Wait until this task is signaled to exit.
|
||||||
|
<-sem
|
||||||
|
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
case *pool.Read:
|
||||||
|
err = pp.Submit(func(buf *buffer.Read) error {
|
||||||
|
// Assert that all of the bytes in the provided array
|
||||||
|
// are zero, indicating that the buffer was reset
|
||||||
|
// between uses.
|
||||||
|
for i := range buf[:] {
|
||||||
|
if buf[i] != 0x00 {
|
||||||
|
return fmt.Errorf("byte %d of "+
|
||||||
|
"buffer.Read should be "+
|
||||||
|
"0, instead is %d", i, buf[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sample some random bytes to read into the buffer.
|
||||||
|
_, err := io.ReadFull(crand.Reader, buf[:])
|
||||||
|
|
||||||
|
// Wait until this task is signaled to exit.
|
||||||
|
<-sem
|
||||||
|
|
||||||
|
return err
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown worker pool type: %T", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to submit task: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
100
pool/write.go
Normal file
100
pool/write.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/buffer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Write is a worker pool specifically designed for sharing access to
|
||||||
|
// buffer.Write objects amongst a set of worker goroutines. This enables an
|
||||||
|
// application to limit the total number of buffer.Write objects allocated at
|
||||||
|
// any given time.
|
||||||
|
type Write struct {
|
||||||
|
workerPool *Worker
|
||||||
|
bufferPool *WriteBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWrite creates a Write pool, using an underlying Writebuffer pool to
|
||||||
|
// recycle buffer.Write objects accross the lifetime of the Write pool's
|
||||||
|
// workers.
|
||||||
|
func NewWrite(writeBufferPool *WriteBuffer, numWorkers int,
|
||||||
|
workerTimeout time.Duration) *Write {
|
||||||
|
|
||||||
|
w := &Write{
|
||||||
|
bufferPool: writeBufferPool,
|
||||||
|
}
|
||||||
|
w.workerPool = NewWorker(&WorkerConfig{
|
||||||
|
NewWorkerState: w.newWorkerState,
|
||||||
|
NumWorkers: numWorkers,
|
||||||
|
WorkerTimeout: workerTimeout,
|
||||||
|
})
|
||||||
|
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start safely spins up the Write pool.
|
||||||
|
func (w *Write) Start() error {
|
||||||
|
return w.workerPool.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop safely shuts down the Write pool.
|
||||||
|
func (w *Write) Stop() error {
|
||||||
|
return w.workerPool.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit accepts a function closure that provides access to a fresh
|
||||||
|
// bytes.Buffer backed by a buffer.Write object. The function's execution will
|
||||||
|
// be allocated to one of the underlying Worker pool's goroutines.
|
||||||
|
func (w *Write) Submit(inner func(*bytes.Buffer) error) error {
|
||||||
|
return w.workerPool.Submit(func(s WorkerState) error {
|
||||||
|
state := s.(*writeWorkerState)
|
||||||
|
return inner(state.buf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeWorkerState is the per-goroutine state maintained by a Write pool's
|
||||||
|
// goroutines.
|
||||||
|
type writeWorkerState struct {
|
||||||
|
// bufferPool is the pool to which the writeBuf will be returned when
|
||||||
|
// the goroutine exits.
|
||||||
|
bufferPool *WriteBuffer
|
||||||
|
|
||||||
|
// writeBuf is the buffer taken from the bufferPool on initialization,
|
||||||
|
// which will be used to back the buf object provided to any tasks that
|
||||||
|
// the goroutine processes before exiting.
|
||||||
|
writeBuf *buffer.Write
|
||||||
|
|
||||||
|
// buf is a buffer backed by writeBuf, that can be written to by tasks
|
||||||
|
// submitted to the Write pool. The buf will be reset between each task
|
||||||
|
// processed by a goroutine before exiting, and allows the task
|
||||||
|
// submitters to interact with the writeBuf as if it were an io.Writer.
|
||||||
|
buf *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWorkerState initializes a new writeWorkerState, which will be called
|
||||||
|
// whenever a new goroutine is allocated to begin processing write tasks.
|
||||||
|
func (w *Write) newWorkerState() WorkerState {
|
||||||
|
writeBuf := w.bufferPool.Take()
|
||||||
|
|
||||||
|
return &writeWorkerState{
|
||||||
|
bufferPool: w.bufferPool,
|
||||||
|
writeBuf: writeBuf,
|
||||||
|
buf: bytes.NewBuffer(writeBuf[0:0:len(writeBuf)]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup returns the writeBuf to the underlying buffer pool, and removes the
|
||||||
|
// goroutine's reference to the readBuf and encapsulating buf.
|
||||||
|
func (w *writeWorkerState) Cleanup() {
|
||||||
|
w.bufferPool.Return(w.writeBuf)
|
||||||
|
w.writeBuf = nil
|
||||||
|
w.buf = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the bytes.Buffer so that it is zero-length and has the capacity
|
||||||
|
// of the underlying buffer.Write.k
|
||||||
|
func (w *writeWorkerState) Reset() {
|
||||||
|
w.buf.Reset()
|
||||||
|
}
|
38
server.go
38
server.go
@ -171,7 +171,9 @@ type server struct {
|
|||||||
|
|
||||||
sigPool *lnwallet.SigPool
|
sigPool *lnwallet.SigPool
|
||||||
|
|
||||||
writeBufferPool *pool.WriteBuffer
|
writePool *pool.Write
|
||||||
|
|
||||||
|
readPool *pool.Read
|
||||||
|
|
||||||
// globalFeatures feature vector which affects HTLCs and thus are also
|
// globalFeatures feature vector which affects HTLCs and thus are also
|
||||||
// advertised to other nodes.
|
// advertised to other nodes.
|
||||||
@ -263,16 +265,31 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
|
|||||||
sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db")
|
sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db")
|
||||||
replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier)
|
replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier)
|
||||||
sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog)
|
sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog)
|
||||||
|
|
||||||
writeBufferPool := pool.NewWriteBuffer(
|
writeBufferPool := pool.NewWriteBuffer(
|
||||||
pool.DefaultWriteBufferGCInterval,
|
pool.DefaultWriteBufferGCInterval,
|
||||||
pool.DefaultWriteBufferExpiryInterval,
|
pool.DefaultWriteBufferExpiryInterval,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
writePool := pool.NewWrite(
|
||||||
|
writeBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
readBufferPool := pool.NewReadBuffer(
|
||||||
|
pool.DefaultReadBufferGCInterval,
|
||||||
|
pool.DefaultReadBufferExpiryInterval,
|
||||||
|
)
|
||||||
|
|
||||||
|
readPool := pool.NewRead(
|
||||||
|
readBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
chanDB: chanDB,
|
chanDB: chanDB,
|
||||||
cc: cc,
|
cc: cc,
|
||||||
sigPool: lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
|
sigPool: lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
|
||||||
writeBufferPool: writeBufferPool,
|
writePool: writePool,
|
||||||
|
readPool: readPool,
|
||||||
|
|
||||||
invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
|
invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
|
||||||
|
|
||||||
@ -1010,6 +1027,12 @@ func (s *server) Start() error {
|
|||||||
if err := s.sigPool.Start(); err != nil {
|
if err := s.sigPool.Start(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := s.writePool.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := s.readPool.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := s.cc.chainNotifier.Start(); err != nil {
|
if err := s.cc.chainNotifier.Start(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1102,7 +1125,6 @@ func (s *server) Stop() error {
|
|||||||
|
|
||||||
// Shutdown the wallet, funding manager, and the rpc server.
|
// Shutdown the wallet, funding manager, and the rpc server.
|
||||||
s.chanStatusMgr.Stop()
|
s.chanStatusMgr.Stop()
|
||||||
s.sigPool.Stop()
|
|
||||||
s.cc.chainNotifier.Stop()
|
s.cc.chainNotifier.Stop()
|
||||||
s.chanRouter.Stop()
|
s.chanRouter.Stop()
|
||||||
s.htlcSwitch.Stop()
|
s.htlcSwitch.Stop()
|
||||||
@ -1129,6 +1151,10 @@ func (s *server) Stop() error {
|
|||||||
// Wait for all lingering goroutines to quit.
|
// Wait for all lingering goroutines to quit.
|
||||||
s.wg.Wait()
|
s.wg.Wait()
|
||||||
|
|
||||||
|
s.sigPool.Stop()
|
||||||
|
s.writePool.Stop()
|
||||||
|
s.readPool.Stop()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user