mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-08 05:08:13 +02:00
lnd: Move TlsManager files to new files
This commit is contained in:
251
lnd.go
251
lnd.go
@@ -6,8 +6,6 @@ package lnd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@@ -24,7 +22,6 @@ import (
|
|||||||
proxy "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
proxy "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||||
"github.com/lightningnetwork/lnd/autopilot"
|
"github.com/lightningnetwork/lnd/autopilot"
|
||||||
"github.com/lightningnetwork/lnd/build"
|
"github.com/lightningnetwork/lnd/build"
|
||||||
"github.com/lightningnetwork/lnd/cert"
|
|
||||||
"github.com/lightningnetwork/lnd/chanacceptor"
|
"github.com/lightningnetwork/lnd/chanacceptor"
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
@@ -38,7 +35,6 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/tor"
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
"github.com/lightningnetwork/lnd/walletunlocker"
|
"github.com/lightningnetwork/lnd/walletunlocker"
|
||||||
"github.com/lightningnetwork/lnd/watchtower"
|
"github.com/lightningnetwork/lnd/watchtower"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
@@ -633,253 +629,6 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSManager generates/renews a TLS certificate when needed and returns the
|
|
||||||
// certificate configuration options needed for gRPC and REST.
|
|
||||||
type TLSManager struct {
|
|
||||||
cfg *Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTLSManager returns a reference to a new TLSManager.
|
|
||||||
func NewTLSManager(cfg *Config) *TLSManager {
|
|
||||||
return &TLSManager{
|
|
||||||
cfg: cfg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConfig returns a TLS configuration for the gRPC server and credentials
|
|
||||||
// and a proxy destination for the REST reverse proxy.
|
|
||||||
func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
|
|
||||||
func(net.Addr) (net.Listener, error), func(), error) {
|
|
||||||
|
|
||||||
tlsCfg, cleanUp, err := t.generateOrRenewCert()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we know that we have a ceritificate, let's generate the
|
|
||||||
// required config options.
|
|
||||||
restCreds, err := credentials.NewClientTLSFromFile(
|
|
||||||
t.cfg.TLSCertPath, "",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverCreds := credentials.NewTLS(tlsCfg)
|
|
||||||
serverOpts := []grpc.ServerOption{grpc.Creds(serverCreds)}
|
|
||||||
|
|
||||||
// For our REST dial options, we'll still use TLS, but also increase
|
|
||||||
// the max message size that we'll decode to allow clients to hit
|
|
||||||
// endpoints which return more data such as the DescribeGraph call.
|
|
||||||
// We set this to 200MiB atm. Should be the same value as maxMsgRecvSize
|
|
||||||
// in cmd/lncli/main.go.
|
|
||||||
restDialOpts := []grpc.DialOption{
|
|
||||||
grpc.WithTransportCredentials(restCreds),
|
|
||||||
grpc.WithDefaultCallOptions(
|
|
||||||
grpc.MaxCallRecvMsgSize(lnrpc.MaxGrpcMsgSize),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return a function closure that can be used to listen on a given
|
|
||||||
// address with the current TLS config.
|
|
||||||
restListen := func(addr net.Addr) (net.Listener, error) {
|
|
||||||
// For restListen we will call ListenOnAddress if TLS is
|
|
||||||
// disabled.
|
|
||||||
if t.cfg.DisableRestTLS {
|
|
||||||
return lncfg.ListenOnAddress(addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return lncfg.TLSListenOnAddress(addr, tlsCfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return serverOpts, restDialOpts, restListen, cleanUp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateOrRenewCert generates a new TLS certificate if we're not using one
|
|
||||||
// yet or renews it if it's outdated.
|
|
||||||
func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) {
|
|
||||||
// Genereate a TLS pair if we don't have one yet.
|
|
||||||
err := t.createTLSPair()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
certData, parsedCert, err := cert.LoadCert(
|
|
||||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to see if the certificate needs to be renewed. If it does, we
|
|
||||||
// return the newly generated certificate data instead.
|
|
||||||
reloadedCertData, err := t.certMaintenance(parsedCert)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if reloadedCertData != nil {
|
|
||||||
certData = *reloadedCertData
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsCfg := cert.TLSConfFromCert(certData)
|
|
||||||
|
|
||||||
cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg)
|
|
||||||
|
|
||||||
return tlsCfg, cleanUp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a TLS pair doesn't exist yet, create them and write them to disk.
|
|
||||||
func (t *TLSManager) createTLSPair() error {
|
|
||||||
// Ensure we create TLS key and certificate if they don't exist.
|
|
||||||
if fileExists(t.cfg.TLSCertPath) || fileExists(t.cfg.TLSKeyPath) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rpcsLog.Infof("Generating TLS certificates...")
|
|
||||||
err := cert.GenCertPair(
|
|
||||||
"lnd autogenerated cert", t.cfg.TLSCertPath,
|
|
||||||
t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs,
|
|
||||||
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
|
||||||
t.cfg.TLSCertDuration,
|
|
||||||
)
|
|
||||||
rpcsLog.Infof("Done generating TLS certificates")
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// certMaintenance checks if the certificate IP and domains matches the config,
|
|
||||||
// and renews the certificate if either this data is outdated or the
|
|
||||||
// certificate is expired.
|
|
||||||
func (t *TLSManager) certMaintenance(
|
|
||||||
parsedCert *x509.Certificate) (*tls.Certificate, error) {
|
|
||||||
|
|
||||||
// We check whether the certificate we have on disk match the IPs and
|
|
||||||
// domains specified by the config. If the extra IPs or domains have
|
|
||||||
// changed from when the certificate was created, we will refresh the
|
|
||||||
// certificate if auto refresh is active.
|
|
||||||
refresh := false
|
|
||||||
var err error
|
|
||||||
if t.cfg.TLSAutoRefresh {
|
|
||||||
refresh, err = cert.IsOutdated(
|
|
||||||
parsedCert, t.cfg.TLSExtraIPs,
|
|
||||||
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the certificate expired or it was outdated, delete it and the TLS
|
|
||||||
// key and generate a new pair.
|
|
||||||
if !time.Now().After(parsedCert.NotAfter) && !refresh {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ltndLog.Info("TLS certificate is expired or outdated, " +
|
|
||||||
"generating a new one")
|
|
||||||
|
|
||||||
err = os.Remove(t.cfg.TLSCertPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.Remove(t.cfg.TLSKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rpcsLog.Infof("Renewing TLS certificates...")
|
|
||||||
err = cert.GenCertPair(
|
|
||||||
"lnd autogenerated cert", t.cfg.TLSCertPath,
|
|
||||||
t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs,
|
|
||||||
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
|
||||||
t.cfg.TLSCertDuration,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rpcsLog.Infof("Done renewing TLS certificates")
|
|
||||||
|
|
||||||
// Reload the certificate data.
|
|
||||||
reloadedCertData, _, err := cert.LoadCert(
|
|
||||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &reloadedCertData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setUpLetsEncrypt automatically generates a Let's Encrypt certificate if the
|
|
||||||
// option is set.
|
|
||||||
func (t *TLSManager) setUpLetsEncrypt(certData *tls.Certificate,
|
|
||||||
tlsCfg *tls.Config) func() {
|
|
||||||
|
|
||||||
// If Let's Encrypt is enabled, instantiate autocert to request/renew
|
|
||||||
// the certificates.
|
|
||||||
cleanUp := func() {}
|
|
||||||
if t.cfg.LetsEncryptDomain == "" {
|
|
||||||
return cleanUp
|
|
||||||
}
|
|
||||||
|
|
||||||
ltndLog.Infof("Using Let's Encrypt certificate for domain %v",
|
|
||||||
t.cfg.LetsEncryptDomain)
|
|
||||||
|
|
||||||
manager := autocert.Manager{
|
|
||||||
Cache: autocert.DirCache(t.cfg.LetsEncryptDir),
|
|
||||||
Prompt: autocert.AcceptTOS,
|
|
||||||
HostPolicy: autocert.HostWhitelist(
|
|
||||||
t.cfg.LetsEncryptDomain,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: t.cfg.LetsEncryptListen,
|
|
||||||
Handler: manager.HTTPHandler(nil),
|
|
||||||
}
|
|
||||||
shutdownCompleted := make(chan struct{})
|
|
||||||
cleanUp = func() {
|
|
||||||
err := srv.Shutdown(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
ltndLog.Errorf("Autocert listener shutdown "+
|
|
||||||
" error: %v", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
<-shutdownCompleted
|
|
||||||
ltndLog.Infof("Autocert challenge listener stopped")
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ltndLog.Infof("Autocert challenge listener started "+
|
|
||||||
"at %v", t.cfg.LetsEncryptListen)
|
|
||||||
|
|
||||||
err := srv.ListenAndServe()
|
|
||||||
if err != http.ErrServerClosed {
|
|
||||||
ltndLog.Errorf("autocert http: %v", err)
|
|
||||||
}
|
|
||||||
close(shutdownCompleted)
|
|
||||||
}()
|
|
||||||
|
|
||||||
getCertificate := func(h *tls.ClientHelloInfo) (
|
|
||||||
*tls.Certificate, error) {
|
|
||||||
|
|
||||||
lecert, err := manager.GetCertificate(h)
|
|
||||||
if err != nil {
|
|
||||||
ltndLog.Errorf("GetCertificate: %v", err)
|
|
||||||
return certData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return lecert, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The self-signed tls.cert remains available as fallback.
|
|
||||||
tlsCfg.GetCertificate = getCertificate
|
|
||||||
|
|
||||||
return cleanUp
|
|
||||||
}
|
|
||||||
|
|
||||||
// fileExists reports whether the named file or directory exists.
|
// fileExists reports whether the named file or directory exists.
|
||||||
// This function is taken from https://github.com/btcsuite/btcd
|
// This function is taken from https://github.com/btcsuite/btcd
|
||||||
func fileExists(name string) bool {
|
func fileExists(name string) bool {
|
||||||
|
141
server_test.go
141
server_test.go
@@ -4,152 +4,11 @@
|
|||||||
package lnd
|
package lnd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/pem"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/lncfg"
|
"github.com/lightningnetwork/lnd/lncfg"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestTLSAutoRegeneration creates an expired TLS certificate, to test that a
|
|
||||||
// new TLS certificate pair is regenerated when the old pair expires. This is
|
|
||||||
// necessary because the pair expires after a little over a year.
|
|
||||||
func TestTLSAutoRegeneration(t *testing.T) {
|
|
||||||
tempDirPath := t.TempDir()
|
|
||||||
|
|
||||||
certPath := tempDirPath + "/tls.cert"
|
|
||||||
keyPath := tempDirPath + "/tls.key"
|
|
||||||
|
|
||||||
certDerBytes, keyBytes := genExpiredCertPair(t, tempDirPath)
|
|
||||||
expiredCert, err := x509.ParseCertificate(certDerBytes)
|
|
||||||
require.NoError(t, err, "failed to parse certificate")
|
|
||||||
|
|
||||||
certBuf := bytes.Buffer{}
|
|
||||||
err = pem.Encode(
|
|
||||||
&certBuf, &pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: certDerBytes,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to encode certificate")
|
|
||||||
|
|
||||||
keyBuf := bytes.Buffer{}
|
|
||||||
err = pem.Encode(
|
|
||||||
&keyBuf, &pem.Block{
|
|
||||||
Type: "EC PRIVATE KEY",
|
|
||||||
Bytes: keyBytes,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to encode private key")
|
|
||||||
|
|
||||||
// Write cert and key files.
|
|
||||||
err = ioutil.WriteFile(tempDirPath+"/tls.cert", certBuf.Bytes(), 0644)
|
|
||||||
require.NoError(t, err, "failed to write cert file")
|
|
||||||
err = ioutil.WriteFile(tempDirPath+"/tls.key", keyBuf.Bytes(), 0600)
|
|
||||||
require.NoError(t, err, "failed to write key file")
|
|
||||||
|
|
||||||
rpcListener := net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""}
|
|
||||||
rpcListeners := make([]net.Addr, 0)
|
|
||||||
rpcListeners = append(rpcListeners, &rpcListener)
|
|
||||||
|
|
||||||
// Now let's run getTLSConfig. If it works properly, it should delete
|
|
||||||
// the cert and create a new one.
|
|
||||||
cfg := &Config{
|
|
||||||
TLSCertPath: certPath,
|
|
||||||
TLSKeyPath: keyPath,
|
|
||||||
TLSCertDuration: 42 * time.Hour,
|
|
||||||
RPCListeners: rpcListeners,
|
|
||||||
}
|
|
||||||
tlsManager := NewTLSManager(cfg)
|
|
||||||
_, _, _, cleanUp, err := tlsManager.getConfig()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't retrieve TLS config")
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
|
|
||||||
// Grab the certificate to test that getTLSConfig did its job correctly
|
|
||||||
// and generated a new cert.
|
|
||||||
newCertData, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't grab new certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
newCert, err := x509.ParseCertificate(newCertData.Certificate[0])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse new certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the expired certificate was successfully deleted and
|
|
||||||
// replaced with a new one.
|
|
||||||
if !newCert.NotAfter.After(expiredCert.NotAfter) {
|
|
||||||
t.Fatalf("New certificate expiration is too old")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// genExpiredCertPair generates an expired key/cert pair to test that expired
|
|
||||||
// certificates are being regenerated correctly.
|
|
||||||
func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) {
|
|
||||||
// Max serial number.
|
|
||||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
||||||
|
|
||||||
// Generate a serial number that's below the serialNumberLimit.
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
|
||||||
require.NoError(t, err, "failed to generate serial number")
|
|
||||||
|
|
||||||
host := "lightning"
|
|
||||||
|
|
||||||
// Create a simple ip address for the fake certificate.
|
|
||||||
ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
|
|
||||||
|
|
||||||
dnsNames := []string{host, "unix", "unixpacket"}
|
|
||||||
|
|
||||||
// Construct the certificate template.
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"lnd autogenerated cert"},
|
|
||||||
CommonName: host,
|
|
||||||
},
|
|
||||||
NotBefore: time.Now().Add(-time.Hour * 24),
|
|
||||||
NotAfter: time.Now(),
|
|
||||||
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment |
|
|
||||||
x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
||||||
IsCA: true, // so can sign self.
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
|
|
||||||
DNSNames: dnsNames,
|
|
||||||
IPAddresses: ipAddresses,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a private key for the certificate.
|
|
||||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to generate a private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
certDerBytes, err := x509.CreateCertificate(
|
|
||||||
rand.Reader, &template, &template, &priv.PublicKey, priv,
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to create certificate")
|
|
||||||
|
|
||||||
keyBytes, err := x509.MarshalECPrivateKey(priv)
|
|
||||||
require.NoError(t, err, "unable to encode privkey")
|
|
||||||
|
|
||||||
return certDerBytes, keyBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestShouldPeerBootstrap tests that we properly skip network bootstrap for
|
// TestShouldPeerBootstrap tests that we properly skip network bootstrap for
|
||||||
// the developer networks, and also if bootstrapping is explicitly disabled.
|
// the developer networks, and also if bootstrapping is explicitly disabled.
|
||||||
func TestShouldPeerBootstrap(t *testing.T) {
|
func TestShouldPeerBootstrap(t *testing.T) {
|
||||||
|
273
tls_manager.go
Normal file
273
tls_manager.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package lnd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/cert"
|
||||||
|
"github.com/lightningnetwork/lnd/lncfg"
|
||||||
|
"github.com/lightningnetwork/lnd/lnrpc"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
letsEncryptTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// TLSManager generates/renews a TLS certificate when needed and returns the
|
||||||
|
// certificate configuration options needed for gRPC and REST.
|
||||||
|
type TLSManager struct {
|
||||||
|
cfg *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLSManager returns a reference to a new TLSManager.
|
||||||
|
func NewTLSManager(cfg *Config) *TLSManager {
|
||||||
|
return &TLSManager{
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConfig returns a TLS configuration for the gRPC server and credentials
|
||||||
|
// and a proxy destination for the REST reverse proxy.
|
||||||
|
func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
|
||||||
|
func(net.Addr) (net.Listener, error), func(), error) {
|
||||||
|
|
||||||
|
tlsCfg, cleanUp, err := t.generateOrRenewCert()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we know that we have a ceritificate, let's generate the
|
||||||
|
// required config options.
|
||||||
|
restCreds, err := credentials.NewClientTLSFromFile(
|
||||||
|
t.cfg.TLSCertPath, "",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverCreds := credentials.NewTLS(tlsCfg)
|
||||||
|
serverOpts := []grpc.ServerOption{grpc.Creds(serverCreds)}
|
||||||
|
|
||||||
|
// For our REST dial options, we'll still use TLS, but also increase
|
||||||
|
// the max message size that we'll decode to allow clients to hit
|
||||||
|
// endpoints which return more data such as the DescribeGraph call.
|
||||||
|
// We set this to 200MiB atm. Should be the same value as maxMsgRecvSize
|
||||||
|
// in cmd/lncli/main.go.
|
||||||
|
restDialOpts := []grpc.DialOption{
|
||||||
|
grpc.WithTransportCredentials(restCreds),
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(lnrpc.MaxGrpcMsgSize),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a function closure that can be used to listen on a given
|
||||||
|
// address with the current TLS config.
|
||||||
|
restListen := func(addr net.Addr) (net.Listener, error) {
|
||||||
|
// For restListen we will call ListenOnAddress if TLS is
|
||||||
|
// disabled.
|
||||||
|
if t.cfg.DisableRestTLS {
|
||||||
|
return lncfg.ListenOnAddress(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return lncfg.TLSListenOnAddress(addr, tlsCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverOpts, restDialOpts, restListen, cleanUp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateOrRenewCert generates a new TLS certificate if we're not using one
|
||||||
|
// yet or renews it if it's outdated.
|
||||||
|
func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) {
|
||||||
|
// Genereate a TLS pair if we don't have one yet.
|
||||||
|
err := t.createTLSPair()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certData, parsedCert, err := cert.LoadCert(
|
||||||
|
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check to see if the certificate needs to be renewed. If it does, we
|
||||||
|
// return the newly generated certificate data instead.
|
||||||
|
reloadedCertData, err := t.certMaintenance(parsedCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if reloadedCertData != nil {
|
||||||
|
certData = *reloadedCertData
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsCfg := cert.TLSConfFromCert(certData)
|
||||||
|
|
||||||
|
cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg)
|
||||||
|
|
||||||
|
return tlsCfg, cleanUp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTLSPair creates and writes a TLS pair to disk if the pair doesn't
|
||||||
|
// exist yet. If the TLSEncryptKey setting is on, and a plaintext key is
|
||||||
|
// already written to disk, this function overwrites the plaintext key with
|
||||||
|
// the encrypted form.
|
||||||
|
func (t *TLSManager) createTLSPair() error {
|
||||||
|
// Ensure we create TLS key and certificate if they don't exist.
|
||||||
|
if fileExists(t.cfg.TLSCertPath) || fileExists(t.cfg.TLSKeyPath) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rpcsLog.Infof("Generating TLS certificates...")
|
||||||
|
err := cert.GenCertPair(
|
||||||
|
"lnd autogenerated cert", t.cfg.TLSCertPath,
|
||||||
|
t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs,
|
||||||
|
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
||||||
|
t.cfg.TLSCertDuration,
|
||||||
|
)
|
||||||
|
rpcsLog.Infof("Done generating TLS certificates")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// certMaintenance checks if the certificate IP and domains matches the config,
|
||||||
|
// and renews the certificate if either this data is outdated or the
|
||||||
|
// certificate is expired.
|
||||||
|
func (t *TLSManager) certMaintenance(
|
||||||
|
parsedCert *x509.Certificate) (*tls.Certificate, error) {
|
||||||
|
|
||||||
|
// We check whether the certificate we have on disk match the IPs and
|
||||||
|
// domains specified by the config. If the extra IPs or domains have
|
||||||
|
// changed from when the certificate was created, we will refresh the
|
||||||
|
// certificate if auto refresh is active.
|
||||||
|
refresh := false
|
||||||
|
var err error
|
||||||
|
if t.cfg.TLSAutoRefresh {
|
||||||
|
refresh, err = cert.IsOutdated(
|
||||||
|
parsedCert, t.cfg.TLSExtraIPs,
|
||||||
|
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the certificate expired or it was outdated, delete it and the TLS
|
||||||
|
// key and generate a new pair.
|
||||||
|
if !time.Now().After(parsedCert.NotAfter) && !refresh {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ltndLog.Info("TLS certificate is expired or outdated, " +
|
||||||
|
"generating a new one")
|
||||||
|
|
||||||
|
err = os.Remove(t.cfg.TLSCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Remove(t.cfg.TLSKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rpcsLog.Infof("Renewing TLS certificates...")
|
||||||
|
err = cert.GenCertPair(
|
||||||
|
"lnd autogenerated cert", t.cfg.TLSCertPath,
|
||||||
|
t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs,
|
||||||
|
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
||||||
|
t.cfg.TLSCertDuration,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rpcsLog.Infof("Done renewing TLS certificates")
|
||||||
|
|
||||||
|
// Reload the certificate data.
|
||||||
|
reloadedCertData, _, err := cert.LoadCert(
|
||||||
|
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &reloadedCertData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setUpLetsEncrypt automatically generates a Let's Encrypt certificate if the
|
||||||
|
// option is set.
|
||||||
|
func (t *TLSManager) setUpLetsEncrypt(certData *tls.Certificate,
|
||||||
|
tlsCfg *tls.Config) func() {
|
||||||
|
|
||||||
|
// If Let's Encrypt is enabled, instantiate autocert to request/renew
|
||||||
|
// the certificates.
|
||||||
|
cleanUp := func() {}
|
||||||
|
if t.cfg.LetsEncryptDomain == "" {
|
||||||
|
return cleanUp
|
||||||
|
}
|
||||||
|
|
||||||
|
ltndLog.Infof("Using Let's Encrypt certificate for domain %v",
|
||||||
|
t.cfg.LetsEncryptDomain)
|
||||||
|
|
||||||
|
manager := autocert.Manager{
|
||||||
|
Cache: autocert.DirCache(t.cfg.LetsEncryptDir),
|
||||||
|
Prompt: autocert.AcceptTOS,
|
||||||
|
HostPolicy: autocert.HostWhitelist(
|
||||||
|
t.cfg.LetsEncryptDomain,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: t.cfg.LetsEncryptListen,
|
||||||
|
Handler: manager.HTTPHandler(nil),
|
||||||
|
ReadHeaderTimeout: letsEncryptTimeout,
|
||||||
|
}
|
||||||
|
shutdownCompleted := make(chan struct{})
|
||||||
|
cleanUp = func() {
|
||||||
|
err := srv.Shutdown(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
ltndLog.Errorf("Autocert listener shutdown "+
|
||||||
|
" error: %v", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
<-shutdownCompleted
|
||||||
|
ltndLog.Infof("Autocert challenge listener stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ltndLog.Infof("Autocert challenge listener started "+
|
||||||
|
"at %v", t.cfg.LetsEncryptListen)
|
||||||
|
|
||||||
|
err := srv.ListenAndServe()
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
ltndLog.Errorf("autocert http: %v", err)
|
||||||
|
}
|
||||||
|
close(shutdownCompleted)
|
||||||
|
}()
|
||||||
|
|
||||||
|
getCertificate := func(h *tls.ClientHelloInfo) (
|
||||||
|
*tls.Certificate, error) {
|
||||||
|
|
||||||
|
lecert, err := manager.GetCertificate(h)
|
||||||
|
if err != nil {
|
||||||
|
ltndLog.Errorf("GetCertificate: %v", err)
|
||||||
|
return certData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return lecert, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The self-signed tls.cert remains available as fallback.
|
||||||
|
tlsCfg.GetCertificate = getCertificate
|
||||||
|
|
||||||
|
return cleanUp
|
||||||
|
}
|
150
tls_manager_test.go
Normal file
150
tls_manager_test.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package lnd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestTLSAutoRegeneration creates an expired TLS certificate, to test that a
|
||||||
|
// new TLS certificate pair is regenerated when the old pair expires. This is
|
||||||
|
// necessary because the pair expires after a little over a year.
|
||||||
|
func TestTLSAutoRegeneration(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tempDirPath := t.TempDir()
|
||||||
|
|
||||||
|
certPath := tempDirPath + "/tls.cert"
|
||||||
|
keyPath := tempDirPath + "/tls.key"
|
||||||
|
|
||||||
|
certDerBytes, keyBytes := genExpiredCertPair(t, tempDirPath)
|
||||||
|
expiredCert, err := x509.ParseCertificate(certDerBytes)
|
||||||
|
require.NoError(t, err, "failed to parse certificate")
|
||||||
|
|
||||||
|
certBuf := bytes.Buffer{}
|
||||||
|
err = pem.Encode(
|
||||||
|
&certBuf, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: certDerBytes,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to encode certificate")
|
||||||
|
|
||||||
|
keyBuf := bytes.Buffer{}
|
||||||
|
err = pem.Encode(
|
||||||
|
&keyBuf, &pem.Block{
|
||||||
|
Type: "EC PRIVATE KEY",
|
||||||
|
Bytes: keyBytes,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to encode private key")
|
||||||
|
|
||||||
|
// Write cert and key files.
|
||||||
|
err = ioutil.WriteFile(tempDirPath+"/tls.cert", certBuf.Bytes(), 0644)
|
||||||
|
require.NoError(t, err, "failed to write cert file")
|
||||||
|
err = ioutil.WriteFile(tempDirPath+"/tls.key", keyBuf.Bytes(), 0600)
|
||||||
|
require.NoError(t, err, "failed to write key file")
|
||||||
|
|
||||||
|
rpcListener := net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""}
|
||||||
|
rpcListeners := make([]net.Addr, 0)
|
||||||
|
rpcListeners = append(rpcListeners, &rpcListener)
|
||||||
|
|
||||||
|
// Now let's run getTLSConfig. If it works properly, it should delete
|
||||||
|
// the cert and create a new one.
|
||||||
|
cfg := &Config{
|
||||||
|
TLSCertPath: certPath,
|
||||||
|
TLSKeyPath: keyPath,
|
||||||
|
TLSCertDuration: 42 * time.Hour,
|
||||||
|
RPCListeners: rpcListeners,
|
||||||
|
}
|
||||||
|
tlsManager := NewTLSManager(cfg)
|
||||||
|
_, _, _, cleanUp, err := tlsManager.getConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("couldn't retrieve TLS config")
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
|
// Grab the certificate to test that getTLSConfig did its job correctly
|
||||||
|
// and generated a new cert.
|
||||||
|
newCertData, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("couldn't grab new certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
newCert, err := x509.ParseCertificate(newCertData.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("couldn't parse new certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the expired certificate was successfully deleted and
|
||||||
|
// replaced with a new one.
|
||||||
|
if !newCert.NotAfter.After(expiredCert.NotAfter) {
|
||||||
|
t.Fatalf("New certificate expiration is too old")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// genExpiredCertPair generates an expired key/cert pair to test that expired
|
||||||
|
// certificates are being regenerated correctly.
|
||||||
|
func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) {
|
||||||
|
t.Helper()
|
||||||
|
// Max serial number.
|
||||||
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||||
|
|
||||||
|
// Generate a serial number that's below the serialNumberLimit.
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
require.NoError(t, err, "failed to generate serial number")
|
||||||
|
|
||||||
|
host := "lightning"
|
||||||
|
|
||||||
|
// Create a simple ip address for the fake certificate.
|
||||||
|
ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
|
||||||
|
|
||||||
|
dnsNames := []string{host, "unix", "unixpacket"}
|
||||||
|
|
||||||
|
// Construct the certificate template.
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"lnd autogenerated cert"},
|
||||||
|
CommonName: host,
|
||||||
|
},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour * 24),
|
||||||
|
NotAfter: time.Now(),
|
||||||
|
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment |
|
||||||
|
x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
IsCA: true, // so can sign self.
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
|
||||||
|
DNSNames: dnsNames,
|
||||||
|
IPAddresses: ipAddresses,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a private key for the certificate.
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate a private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
certDerBytes, err := x509.CreateCertificate(
|
||||||
|
rand.Reader, &template, &template, &priv.PublicKey, priv,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to create certificate")
|
||||||
|
|
||||||
|
keyBytes, err := x509.MarshalECPrivateKey(priv)
|
||||||
|
require.NoError(t, err, "unable to encode privkey")
|
||||||
|
|
||||||
|
return certDerBytes, keyBytes
|
||||||
|
}
|
Reference in New Issue
Block a user