From c2a6c86e6b7c168750dab95e8501fa4681e8945c Mon Sep 17 00:00:00 2001 From: nsa Date: Wed, 7 Aug 2019 22:17:50 -0400 Subject: [PATCH] rpcserver: adding ChannelAcceptor bidirectional streaming --- lnd.go | 6 +- rpcserver.go | 188 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 189 insertions(+), 5 deletions(-) diff --git a/lnd.go b/lnd.go index 982e672b0..ee260bfdb 100644 --- a/lnd.go +++ b/lnd.go @@ -43,6 +43,7 @@ import ( "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" @@ -488,6 +489,9 @@ func Main(lisCfg ListenerCfg) error { } } + // Initialize the ChainedAcceptor. + chainedAcceptor := chanacceptor.NewChainedAcceptor() + // Set up the core server which will listen for incoming peer // connections. server, err := newServer( @@ -547,7 +551,7 @@ func Main(lisCfg ListenerCfg) error { rpcServer, err := newRPCServer( server, macaroonService, cfg.SubRPCServers, restDialOpts, restProxyDest, atplManager, server.invoices, tower, tlsCfg, - rpcListeners, + rpcListeners, chainedAcceptor, ) if err != nil { err := fmt.Errorf("Unable to create RPC server: %v", err) diff --git a/rpcserver.go b/rpcserver.go index 97667a6cc..42f1cbab9 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "time" + "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -76,6 +77,12 @@ var ( // It is set to the value under the Bitcoin chain as default. MaxPaymentMSat = maxBtcPaymentMSat + // defaultAcceptorTimeout is the time after which an RPCAcceptor will time + // out and return false if it hasn't yet received a response. + // + // TODO: Make this configurable + defaultAcceptorTimeout = 15 * time.Second + // readPermissions is a slice of all entities that allow read // permissions for authorization purposes, all lowercase. readPermissions = []bakery.Op{ @@ -382,6 +389,13 @@ func mainRPCServerPermissions() map[string][]bakery.Op { Entity: "offchain", Action: "read", }}, + "/lnrpc.Lightning/ChannelAcceptor": {{ + Entity: "onchain", + Action: "write", + }, { + Entity: "offchain", + Action: "write", + }}, } } @@ -430,6 +444,10 @@ type rpcServer struct { // rpc sub server. routerBackend *routerrpc.RouterBackend + // chanPredicate is used in the bidirectional ChannelAcceptor streaming + // method. + chanPredicate *chanacceptor.ChainedAcceptor + quit chan struct{} } @@ -446,7 +464,8 @@ func newRPCServer(s *server, macService *macaroons.Service, subServerCgs *subRPCServerConfigs, restDialOpts []grpc.DialOption, restProxyDest string, atpl *autopilot.Manager, invoiceRegistry *invoices.InvoiceRegistry, tower *watchtower.Standalone, - tlsCfg *tls.Config, getListeners rpcListeners) (*rpcServer, error) { + tlsCfg *tls.Config, getListeners rpcListeners, + chanPredicate *chanacceptor.ChainedAcceptor) (*rpcServer, error) { // Set up router rpc backend. channelGraph := s.chanDB.ChannelGraph() @@ -601,6 +620,7 @@ func newRPCServer(s *server, macService *macaroons.Service, grpcServer: grpcServer, server: s, routerBackend: routerBackend, + chanPredicate: chanPredicate, quit: make(chan struct{}, 1), } lnrpc.RegisterLightningServer(grpcServer, rootRPCServer) @@ -5052,7 +5072,167 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription } } -// ChannelAcceptor method stub. -func (r *rpcServer) ChannelAcceptor(stream lnrpc.Lightning_ChannelAcceptorServer) error { - return nil +// chanAcceptInfo is used in the ChannelAcceptor bidirectional stream and +// encapsulates the request information sent from the RPCAcceptor to the +// RPCServer. +type chanAcceptInfo struct { + chanReq *chanacceptor.ChannelAcceptRequest + responseChan chan bool +} + +// ChannelAcceptor dispatches a bi-directional streaming RPC in which +// OpenChannel requests are sent to the client and the client responds with +// a boolean that tells LND whether or not to accept the channel. This allows +// node operators to specify their own criteria for accepting inbound channels +// through a single persistent connection. +func (r *rpcServer) ChannelAcceptor(stream lnrpc.Lightning_ChannelAcceptorServer) error { + chainedAcceptor := r.chanPredicate + + // Create two channels to handle requests and responses respectively. + newRequests := make(chan *chanAcceptInfo) + responses := make(chan lnrpc.ChannelAcceptResponse) + + // Define a quit channel that will be used to signal to the RPCAcceptor's + // closure whether the stream still exists. + quit := make(chan struct{}) + defer close(quit) + + // demultiplexReq is a closure that will be passed to the RPCAcceptor and + // acts as an intermediary between the RPCAcceptor and the RPCServer. + demultiplexReq := func(req *chanacceptor.ChannelAcceptRequest) bool { + respChan := make(chan bool, 1) + + newRequest := &chanAcceptInfo{ + chanReq: req, + responseChan: respChan, + } + + // timeout is the time after which ChannelAcceptRequests expire. + timeout := time.After(defaultAcceptorTimeout) + + // Send the request to the newRequests channel. + select { + case newRequests <- newRequest: + case <-timeout: + rpcsLog.Errorf("RPCAcceptor returned false - reached timeout of %d", + defaultAcceptorTimeout) + return false + case <-quit: + return false + case <-r.quit: + return false + } + + // Receive the response and return it. If no response has been received + // in defaultAcceptorTimeout, then return false. + select { + case resp := <-respChan: + return resp + case <-timeout: + rpcsLog.Errorf("RPCAcceptor returned false - reached timeout of %d", + defaultAcceptorTimeout) + return false + case <-quit: + return false + case <-r.quit: + return false + } + } + + // Create a new RPCAcceptor via the NewRPCAcceptor method. + rpcAcceptor := chanacceptor.NewRPCAcceptor(demultiplexReq) + + // Add the RPCAcceptor to the ChainedAcceptor and defer its removal. + id := chainedAcceptor.AddAcceptor(rpcAcceptor) + defer chainedAcceptor.RemoveAcceptor(id) + + // errChan is used by the receive loop to signal any errors that occur + // during reading from the stream. This is primarily used to shutdown the + // send loop in the case of an RPC client disconnecting. + errChan := make(chan error, 1) + + // We need to have the stream.Recv() in a goroutine since the call is + // blocking and would prevent us from sending more ChannelAcceptRequests to + // the RPC client. + go func() { + for { + resp, err := stream.Recv() + if err != nil { + errChan <- err + return + } + + var pendingID [32]byte + copy(pendingID[:], resp.PendingChanId) + + openChanResp := lnrpc.ChannelAcceptResponse{ + Accept: resp.Accept, + PendingChanId: pendingID[:], + } + + // Now that we have the response from the RPC client, send it to + // the responses chan. + select { + case responses <- openChanResp: + case <-quit: + return + case <-r.quit: + return + } + } + }() + + acceptRequests := make(map[[32]byte]chan bool) + + for { + select { + case newRequest := <-newRequests: + + req := newRequest.chanReq + pendingChanID := req.OpenChanMsg.PendingChannelID + + acceptRequests[pendingChanID] = newRequest.responseChan + + // A ChannelAcceptRequest has been received, send it to the client. + chanAcceptReq := &lnrpc.ChannelAcceptRequest{ + NodePubkey: req.Node.SerializeCompressed(), + ChainHash: req.OpenChanMsg.ChainHash[:], + PendingChanId: req.OpenChanMsg.PendingChannelID[:], + FundingAmt: uint64(req.OpenChanMsg.FundingAmount), + PushAmt: uint64(req.OpenChanMsg.PushAmount), + DustLimit: uint64(req.OpenChanMsg.DustLimit), + MaxValueInFlight: uint64(req.OpenChanMsg.MaxValueInFlight), + ChannelReserve: uint64(req.OpenChanMsg.ChannelReserve), + MinHtlc: uint64(req.OpenChanMsg.HtlcMinimum), + FeePerKw: uint64(req.OpenChanMsg.FeePerKiloWeight), + CsvDelay: uint32(req.OpenChanMsg.CsvDelay), + MaxAcceptedHtlcs: uint32(req.OpenChanMsg.MaxAcceptedHTLCs), + ChannelFlags: uint32(req.OpenChanMsg.ChannelFlags), + } + + if err := stream.Send(chanAcceptReq); err != nil { + return err + } + case resp := <-responses: + // Look up the appropriate channel to send on given the pending ID. + // If a channel is found, send the response over it. + var pendingID [32]byte + copy(pendingID[:], resp.PendingChanId) + respChan, ok := acceptRequests[pendingID] + if !ok { + continue + } + + // Send the response boolean over the buffered response channel. + respChan <- resp.Accept + + // Delete the channel from the acceptRequests map. + delete(acceptRequests, pendingID) + case err := <-errChan: + rpcsLog.Errorf("Received an error: %v, shutting down", err) + return err + case <-r.quit: + return fmt.Errorf("RPC server is shutting down") + } + } }