mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-07-07 22:10:27 +02:00
lnd: Add ability to encrypt TLS key on disk
This commit is contained in:
@ -271,6 +271,7 @@ type Config struct {
|
||||
TLSAutoRefresh bool `long:"tlsautorefresh" description:"Re-generate TLS certificate and key if the IPs or domains are changed"`
|
||||
TLSDisableAutofill bool `long:"tlsdisableautofill" description:"Do not include the interface IPs or the system hostname in TLS certificate, use first --tlsextradomain as Common Name instead, if set"`
|
||||
TLSCertDuration time.Duration `long:"tlscertduration" description:"The duration for which the auto-generated TLS certificate will be valid for"`
|
||||
TLSEncryptKey bool `long:"tlsencryptkey" description:"Automatically encrypts the TLS private key and generates ephemeral TLS key pairs when the wallet is locked or not initialized"`
|
||||
|
||||
NoMacaroons bool `long:"no-macaroons" description:"Disable macaroon authentication, can only be used if server is not listening on a public interface."`
|
||||
AdminMacPath string `long:"adminmacaroonpath" description:"Path to write the admin macaroon for lnd's RPC and REST services if it doesn't exist"`
|
||||
|
2
go.mod
2
go.mod
@ -32,7 +32,7 @@ require (
|
||||
github.com/lightninglabs/neutrino v0.14.2
|
||||
github.com/lightninglabs/protobuf-hex-display v1.4.3-hex-display
|
||||
github.com/lightningnetwork/lightning-onion v1.2.1-0.20221202012345-ca23184850a1
|
||||
github.com/lightningnetwork/lnd/cert v1.1.1
|
||||
github.com/lightningnetwork/lnd/cert v1.2.0
|
||||
github.com/lightningnetwork/lnd/clock v1.1.0
|
||||
github.com/lightningnetwork/lnd/healthcheck v1.2.2
|
||||
github.com/lightningnetwork/lnd/kvdb v1.3.1
|
||||
|
2
go.sum
2
go.sum
@ -447,6 +447,8 @@ github.com/lightningnetwork/lightning-onion v1.2.1-0.20221202012345-ca23184850a1
|
||||
github.com/lightningnetwork/lightning-onion v1.2.1-0.20221202012345-ca23184850a1/go.mod h1:7dDx73ApjEZA0kcknI799m2O5kkpfg4/gr7N092ojNo=
|
||||
github.com/lightningnetwork/lnd/cert v1.1.1 h1:Nsav0RlIDRbOnzz2Yu69SQlK939IKya3Q2S0mDviIN8=
|
||||
github.com/lightningnetwork/lnd/cert v1.1.1/go.mod h1:1P46svkkd73oSoeI4zjkVKgZNwGq8bkGuPR8z+5vQUs=
|
||||
github.com/lightningnetwork/lnd/cert v1.2.0 h1:IWfjHNMI5JgQZU5fdvDptF3DkVI38f4jO/s3tYgWFbE=
|
||||
github.com/lightningnetwork/lnd/cert v1.2.0/go.mod h1:04JhIEodoR6usBN5+XBRtLEEmEHsclLi0tEyxZQNP+w=
|
||||
github.com/lightningnetwork/lnd/clock v1.0.1/go.mod h1:KnQudQ6w0IAMZi1SgvecLZQZ43ra2vpDNj7H/aasemg=
|
||||
github.com/lightningnetwork/lnd/clock v1.1.0 h1:/yfVAwtPmdx45aQBoXQImeY7sOIEr7IXlImRMBOZ7GQ=
|
||||
github.com/lightningnetwork/lnd/clock v1.1.0/go.mod h1:KnQudQ6w0IAMZi1SgvecLZQZ43ra2vpDNj7H/aasemg=
|
||||
|
49
lnd.go
49
lnd.go
@ -213,15 +213,31 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
|
||||
return mkErr("error initializing DBs: %v", err)
|
||||
}
|
||||
|
||||
// Only process macaroons if --no-macaroons isn't set.
|
||||
tlsManager := NewTLSManager(cfg)
|
||||
serverOpts, restDialOpts, restListen, cleanUp,
|
||||
err := tlsManager.getConfig()
|
||||
tlsManagerCfg := &TLSManagerCfg{
|
||||
TLSCertPath: cfg.TLSCertPath,
|
||||
TLSKeyPath: cfg.TLSKeyPath,
|
||||
TLSEncryptKey: cfg.TLSEncryptKey,
|
||||
TLSExtraIPs: cfg.TLSExtraIPs,
|
||||
TLSExtraDomains: cfg.TLSExtraDomains,
|
||||
TLSAutoRefresh: cfg.TLSAutoRefresh,
|
||||
TLSDisableAutofill: cfg.TLSDisableAutofill,
|
||||
TLSCertDuration: cfg.TLSCertDuration,
|
||||
|
||||
if err != nil {
|
||||
return mkErr("unable to load TLS credentials: %v", err)
|
||||
LetsEncryptDir: cfg.LetsEncryptDir,
|
||||
LetsEncryptDomain: cfg.LetsEncryptDomain,
|
||||
LetsEncryptListen: cfg.LetsEncryptListen,
|
||||
|
||||
DisableRestTLS: cfg.DisableRestTLS,
|
||||
}
|
||||
tlsManager := NewTLSManager(tlsManagerCfg)
|
||||
serverOpts, restDialOpts, restListen, cleanUp,
|
||||
err := tlsManager.SetCertificateBeforeUnlock()
|
||||
if err != nil {
|
||||
return mkErr("error setting cert before unlock: %v", err)
|
||||
}
|
||||
if cleanUp != nil {
|
||||
defer cleanUp()
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// If we have chosen to start with a dedicated listener for the
|
||||
// rpc server, we set it directly.
|
||||
@ -512,7 +528,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
|
||||
server, err := newServer(
|
||||
cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc,
|
||||
activeChainControl.Cfg.WalletUnlockParams.ChansToRestore,
|
||||
multiAcceptor, torController,
|
||||
multiAcceptor, torController, tlsManager,
|
||||
)
|
||||
if err != nil {
|
||||
return mkErr("unable to create server: %v", err)
|
||||
@ -538,6 +554,12 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
|
||||
}
|
||||
defer atplManager.Stop()
|
||||
|
||||
err = tlsManager.LoadPermanentCertificate(activeChainControl.KeyRing)
|
||||
if err != nil {
|
||||
return mkErr("unable to load permanent TLS certificate: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
// Now we have created all dependencies necessary to populate and
|
||||
// start the RPC server.
|
||||
err = rpcServer.addDeps(
|
||||
@ -629,17 +651,6 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
|
||||
return nil
|
||||
}
|
||||
|
||||
// fileExists reports whether the named file or directory exists.
|
||||
// This function is taken from https://github.com/btcsuite/btcd
|
||||
func fileExists(name string) bool {
|
||||
if _, err := os.Stat(name); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// bakeMacaroon creates a new macaroon with newest version and the given
|
||||
// permissions then returns it binary serialized.
|
||||
func bakeMacaroon(ctx context.Context, svc *macaroons.Service,
|
||||
|
19
server.go
19
server.go
@ -27,7 +27,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/aliasmgr"
|
||||
"github.com/lightningnetwork/lnd/autopilot"
|
||||
"github.com/lightningnetwork/lnd/brontide"
|
||||
"github.com/lightningnetwork/lnd/cert"
|
||||
"github.com/lightningnetwork/lnd/chainreg"
|
||||
"github.com/lightningnetwork/lnd/chanacceptor"
|
||||
"github.com/lightningnetwork/lnd/chanbackup"
|
||||
@ -294,6 +293,8 @@ type server struct {
|
||||
|
||||
readPool *pool.Read
|
||||
|
||||
tlsManager *TLSManager
|
||||
|
||||
// featureMgr dispatches feature vectors for various contexts within the
|
||||
// daemon.
|
||||
featureMgr *feature.Manager
|
||||
@ -473,7 +474,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
||||
nodeKeyDesc *keychain.KeyDescriptor,
|
||||
chansToRestore walletunlocker.ChannelsToRecover,
|
||||
chanPredicate chanacceptor.ChannelAcceptor,
|
||||
torController *tor.Controller) (*server, error) {
|
||||
torController *tor.Controller, tlsManager *TLSManager) (*server,
|
||||
error) {
|
||||
|
||||
var (
|
||||
err error
|
||||
@ -600,6 +602,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
||||
|
||||
customMessageServer: subscribe.NewServer(),
|
||||
|
||||
tlsManager: tlsManager,
|
||||
|
||||
featureMgr: featureMgr,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
@ -1640,18 +1644,15 @@ func (s *server) createLivenessMonitor(cfg *Config, cc *chainreg.ChainControl) {
|
||||
tlsHealthCheck := healthcheck.NewObservation(
|
||||
"tls",
|
||||
func() error {
|
||||
_, parsedCert, err := cert.LoadCert(
|
||||
cfg.TLSCertPath, cfg.TLSKeyPath,
|
||||
expired, expTime, err := s.tlsManager.IsCertExpired(
|
||||
s.cc.KeyRing,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the current time is passed the certificate's
|
||||
// expiry time, then it is considered expired
|
||||
if time.Now().After(parsedCert.NotAfter) {
|
||||
if expired {
|
||||
return fmt.Errorf("TLS certificate is "+
|
||||
"expired as of %v", parsedCert.NotAfter)
|
||||
"expired as of %v", expTime)
|
||||
}
|
||||
|
||||
// If the certificate is not outdated, no error needs
|
||||
|
442
tls_manager.go
442
tls_manager.go
@ -1,34 +1,88 @@
|
||||
package lnd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/cert"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lncfg"
|
||||
"github.com/lightningnetwork/lnd/lnencrypt"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// modifyFilePermissons is the file permission used for writing
|
||||
// encrypted tls files.
|
||||
modifyFilePermissions = 0600
|
||||
|
||||
// validityHours is the number of hours the ephemeral tls certificate
|
||||
// will be valid, if encrypting tls certificates is turned on.
|
||||
validityHours = 24
|
||||
)
|
||||
|
||||
var (
|
||||
// privateKeyPrefix is the prefix to a plaintext TLS key.
|
||||
privateKeyPrefix = []byte("-----BEGIN EC PRIVATE KEY-----")
|
||||
|
||||
// letsEncryptTimeout sets a timeout for the Lets Encrypt server.
|
||||
letsEncryptTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// TLSManager generates/renews a TLS certificate when needed and returns the
|
||||
// certificate configuration options needed for gRPC and REST.
|
||||
// TLSManagerCfg houses a set of values and methods that is passed to the
|
||||
// TLSManager for it to properly manage LND's TLS options.
|
||||
type TLSManagerCfg struct {
|
||||
TLSCertPath string
|
||||
TLSKeyPath string
|
||||
TLSEncryptKey bool
|
||||
TLSExtraIPs []string
|
||||
TLSExtraDomains []string
|
||||
TLSAutoRefresh bool
|
||||
TLSDisableAutofill bool
|
||||
TLSCertDuration time.Duration
|
||||
|
||||
LetsEncryptDir string
|
||||
LetsEncryptDomain string
|
||||
LetsEncryptListen string
|
||||
|
||||
DisableRestTLS bool
|
||||
}
|
||||
|
||||
// TLSManager generates/renews a TLS cert/key pair when needed. When required,
|
||||
// it encrypts the TLS key. It also returns the certificate configuration
|
||||
// options needed for gRPC and REST.
|
||||
type TLSManager struct {
|
||||
cfg *Config
|
||||
cfg *TLSManagerCfg
|
||||
|
||||
// tlsReloader is able to reload the certificate with the
|
||||
// GetCertificate function. In getConfig, tlsCfg.GetCertificate is
|
||||
// pointed towards t.tlsReloader.GetCertificateFunc(). When
|
||||
// TLSReloader's AttemptReload is called, the cert that tlsReloader
|
||||
// holds is changed, in turn changing the cert data
|
||||
// tlsCfg.GetCertificate will return.
|
||||
tlsReloader *cert.TLSReloader
|
||||
|
||||
// These options are only used if we're currently using an ephemeral
|
||||
// TLS certificate, used when we're encrypting the TLS key.
|
||||
ephemeralKey []byte
|
||||
ephemeralCert []byte
|
||||
ephemeralCertPath string
|
||||
}
|
||||
|
||||
// NewTLSManager returns a reference to a new TLSManager.
|
||||
func NewTLSManager(cfg *Config) *TLSManager {
|
||||
func NewTLSManager(cfg *TLSManagerCfg) *TLSManager {
|
||||
return &TLSManager{
|
||||
cfg: cfg,
|
||||
}
|
||||
@ -37,20 +91,54 @@ func NewTLSManager(cfg *Config) *TLSManager {
|
||||
// getConfig returns a TLS configuration for the gRPC server and credentials
|
||||
// and a proxy destination for the REST reverse proxy.
|
||||
func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
|
||||
func(net.Addr) (net.Listener, error), func(), error) {
|
||||
func(net.Addr) (net.Listener, error), error) {
|
||||
|
||||
tlsCfg, cleanUp, err := t.generateOrRenewCert()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
var (
|
||||
keyBytes, certBytes []byte
|
||||
err error
|
||||
)
|
||||
if t.ephemeralKey != nil {
|
||||
keyBytes = t.ephemeralKey
|
||||
certBytes = t.ephemeralCert
|
||||
} else {
|
||||
certBytes, keyBytes, err = cert.GetCertBytesFromPath(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we know that we have a ceritificate, let's generate the
|
||||
certData, _, err := cert.LoadCertFromBytes(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if t.tlsReloader == nil {
|
||||
tlsr, err := cert.NewTLSReloader(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
t.tlsReloader = tlsr
|
||||
}
|
||||
|
||||
tlsCfg := cert.TLSConfFromCert(certData)
|
||||
tlsCfg.GetCertificate = t.tlsReloader.GetCertificateFunc()
|
||||
|
||||
// If we're using the ephemeral certificate, we need to use the
|
||||
// ephemeral cert path.
|
||||
certPath := t.cfg.TLSCertPath
|
||||
if t.ephemeralCertPath != "" {
|
||||
certPath = t.ephemeralCertPath
|
||||
}
|
||||
|
||||
// Now that we know that we have a certificate, let's generate the
|
||||
// required config options.
|
||||
restCreds, err := credentials.NewClientTLSFromFile(
|
||||
t.cfg.TLSCertPath, "",
|
||||
certPath, "",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
serverCreds := credentials.NewTLS(tlsCfg)
|
||||
@ -80,14 +168,15 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
|
||||
return lncfg.TLSListenOnAddress(addr, tlsCfg)
|
||||
}
|
||||
|
||||
return serverOpts, restDialOpts, restListen, cleanUp, nil
|
||||
return serverOpts, restDialOpts, restListen, nil
|
||||
}
|
||||
|
||||
// generateOrRenewCert generates a new TLS certificate if we're not using one
|
||||
// yet or renews it if it's outdated.
|
||||
func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) {
|
||||
// Genereate a TLS pair if we don't have one yet.
|
||||
err := t.createTLSPair()
|
||||
// Generete a TLS pair if we don't have one yet.
|
||||
var emptyKeyRing keychain.SecretKeyRing
|
||||
err := t.generateCertPair(emptyKeyRing)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -101,7 +190,7 @@ func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) {
|
||||
|
||||
// Check to see if the certificate needs to be renewed. If it does, we
|
||||
// return the newly generated certificate data instead.
|
||||
reloadedCertData, err := t.certMaintenance(parsedCert)
|
||||
reloadedCertData, err := t.maintainCert(parsedCert)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -110,38 +199,134 @@ func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) {
|
||||
}
|
||||
|
||||
tlsCfg := cert.TLSConfFromCert(certData)
|
||||
|
||||
cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg)
|
||||
|
||||
return tlsCfg, cleanUp, nil
|
||||
}
|
||||
|
||||
// createTLSPair creates and writes a TLS pair to disk if the pair doesn't
|
||||
// exist yet. If the TLSEncryptKey setting is on, and a plaintext key is
|
||||
// already written to disk, this function overwrites the plaintext key with
|
||||
// generateCertPair creates and writes a TLS pair to disk if the pair
|
||||
// doesn't exist yet. If the TLSEncryptKey setting is on, and a plaintext key
|
||||
// is already written to disk, this function overwrites the plaintext key with
|
||||
// the encrypted form.
|
||||
func (t *TLSManager) createTLSPair() error {
|
||||
func (t *TLSManager) generateCertPair(keyRing keychain.SecretKeyRing) error {
|
||||
// Ensure we create TLS key and certificate if they don't exist.
|
||||
if fileExists(t.cfg.TLSCertPath) || fileExists(t.cfg.TLSKeyPath) {
|
||||
return nil
|
||||
// Handle discrepencies related to the TLSEncryptKey setting.
|
||||
return t.ensureEncryption(keyRing)
|
||||
}
|
||||
|
||||
rpcsLog.Infof("Generating TLS certificates...")
|
||||
err := cert.GenCertPair(
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd autogenerated cert", t.cfg.TLSCertPath,
|
||||
t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs,
|
||||
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
||||
t.cfg.TLSCertDuration,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.cfg.TLSEncryptKey {
|
||||
var b bytes.Buffer
|
||||
e, err := lnencrypt.KeyRingEncrypter(keyRing)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create "+
|
||||
"encrypt key %v", err)
|
||||
}
|
||||
|
||||
err = e.EncryptPayloadToWriter(
|
||||
keyBytes, &b,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyBytes = b.Bytes()
|
||||
}
|
||||
|
||||
err = cert.WriteCertPair(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath, certBytes, keyBytes,
|
||||
)
|
||||
|
||||
rpcsLog.Infof("Done generating TLS certificates")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// certMaintenance checks if the certificate IP and domains matches the config,
|
||||
// ensureEncryption takes a look at a couple of things:
|
||||
// 1) If the TLS key is in plaintext, but TLSEncryptKey is set, we need to
|
||||
// encrypt the file and rewrite it to disk.
|
||||
// 2) On the flip side, if TLSEncryptKey is not set, but the key on disk
|
||||
// is encrypted, we need to error out and warn the user.
|
||||
func (t *TLSManager) ensureEncryption(keyRing keychain.SecretKeyRing) error {
|
||||
_, keyBytes, err := cert.GetCertBytesFromPath(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.cfg.TLSEncryptKey && bytes.HasPrefix(keyBytes, privateKeyPrefix) {
|
||||
var b bytes.Buffer
|
||||
e, err := lnencrypt.KeyRingEncrypter(keyRing)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to generate encrypt key %w",
|
||||
err)
|
||||
}
|
||||
|
||||
err = e.EncryptPayloadToWriter(keyBytes, &b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(
|
||||
t.cfg.TLSKeyPath, b.Bytes(), modifyFilePermissions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If the private key is encrypted but the user didn't pass
|
||||
// --tlsencryptkey we error out. This is because the wallet is not
|
||||
// unlocked yet and we don't have access to the keys yet for decryption.
|
||||
if !t.cfg.TLSEncryptKey && !bytes.HasPrefix(keyBytes,
|
||||
privateKeyPrefix) {
|
||||
|
||||
ltndLog.Errorf("The TLS private key is encrypted on disk.")
|
||||
|
||||
return errors.New("the TLS key is encrypted but the " +
|
||||
"--tlsencryptkey flag is not passed. Please either " +
|
||||
"restart lnd with the --tlsencryptkey flag or delete " +
|
||||
"the TLS files for regeneration")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptTLSKeyBytes decrypts the TLS key.
|
||||
func decryptTLSKeyBytes(keyRing keychain.SecretKeyRing,
|
||||
encryptedData []byte) ([]byte, error) {
|
||||
|
||||
reader := bytes.NewReader(encryptedData)
|
||||
encrypter, err := lnencrypt.KeyRingEncrypter(keyRing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plaintext, err := encrypter.DecryptPayloadFromReader(
|
||||
reader,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// maintainCert checks if the certificate IP and domains matches the config,
|
||||
// and renews the certificate if either this data is outdated or the
|
||||
// certificate is expired.
|
||||
func (t *TLSManager) certMaintenance(
|
||||
func (t *TLSManager) maintainCert(
|
||||
parsedCert *x509.Certificate) (*tls.Certificate, error) {
|
||||
|
||||
// We check whether the certificate we have on disk match the IPs and
|
||||
@ -180,26 +365,30 @@ func (t *TLSManager) certMaintenance(
|
||||
}
|
||||
|
||||
rpcsLog.Infof("Renewing TLS certificates...")
|
||||
err = cert.GenCertPair(
|
||||
"lnd autogenerated cert", t.cfg.TLSCertPath,
|
||||
t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs,
|
||||
t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill,
|
||||
t.cfg.TLSCertDuration,
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd autogenerated cert", t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||
t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains,
|
||||
t.cfg.TLSDisableAutofill, t.cfg.TLSCertDuration,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = cert.WriteCertPair(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath, certBytes, keyBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpcsLog.Infof("Done renewing TLS certificates")
|
||||
|
||||
// Reload the certificate data.
|
||||
reloadedCertData, _, err := cert.LoadCert(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &reloadedCertData, nil
|
||||
return &reloadedCertData, err
|
||||
}
|
||||
|
||||
// setUpLetsEncrypt automatically generates a Let's Encrypt certificate if the
|
||||
@ -271,3 +460,190 @@ func (t *TLSManager) setUpLetsEncrypt(certData *tls.Certificate,
|
||||
|
||||
return cleanUp
|
||||
}
|
||||
|
||||
// SetCertificateBeforeUnlock takes care of loading the certificate before
|
||||
// the wallet is unlocked. If the TLSEncryptKey setting is on, we need to
|
||||
// generate an ephemeral certificate we're able to use until the wallet is
|
||||
// unlocked and a new TLS pair can be encrypted to disk. Otherwise we can
|
||||
// process the certificate normally.
|
||||
func (t *TLSManager) SetCertificateBeforeUnlock() ([]grpc.ServerOption,
|
||||
[]grpc.DialOption, func(net.Addr) (net.Listener, error), func(),
|
||||
error) {
|
||||
|
||||
var cleanUp func()
|
||||
if t.cfg.TLSEncryptKey {
|
||||
_, err := t.loadEphemeralCertificate()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("unable to load "+
|
||||
"ephemeral certificate: %v", err)
|
||||
}
|
||||
} else {
|
||||
_, cleanUpFunc, err := t.generateOrRenewCert()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("unable to "+
|
||||
"generate or renew TLS certificate: %v", err)
|
||||
}
|
||||
cleanUp = cleanUpFunc
|
||||
}
|
||||
|
||||
serverOpts, restDialOpts, restListen, err := t.getConfig()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("unable to load TLS "+
|
||||
"credentials: %v", err)
|
||||
}
|
||||
|
||||
return serverOpts, restDialOpts, restListen, cleanUp, nil
|
||||
}
|
||||
|
||||
// loadEphemeralCertificate creates and loads the ephemeral certificate which
|
||||
// is used temporarily for secure communications before the wallet is unlocked.
|
||||
func (t *TLSManager) loadEphemeralCertificate() ([]byte, error) {
|
||||
rpcsLog.Infof("Generating ephemeral TLS certificates...")
|
||||
|
||||
tmpValidity := validityHours * time.Hour
|
||||
// Append .tmp to the end of the cert for differentiation.
|
||||
tmpCertPath := t.cfg.TLSCertPath + ".tmp"
|
||||
|
||||
// Pass in a blank string for the key path so the
|
||||
// function doesn't write them to disk.
|
||||
certBytes, keyBytes, err := cert.GenCertPair(
|
||||
"lnd ephemeral autogenerated cert", tmpCertPath,
|
||||
"", t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains,
|
||||
t.cfg.TLSDisableAutofill, tmpValidity,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.setEphemeralSettings(keyBytes, certBytes, t.cfg.TLSCertPath+".tmp")
|
||||
|
||||
err = cert.WriteCertPair(tmpCertPath, "", certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpcsLog.Infof("Done generating ephemeral TLS certificates")
|
||||
|
||||
return keyBytes, nil
|
||||
}
|
||||
|
||||
// LoadPermanentCertificate deletes the ephemeral certificate file and
|
||||
// generates a new one with the real keyring.
|
||||
func (t *TLSManager) LoadPermanentCertificate(
|
||||
keyRing keychain.SecretKeyRing) error {
|
||||
|
||||
if !t.cfg.TLSEncryptKey {
|
||||
return nil
|
||||
}
|
||||
|
||||
tmpCertPath := t.cfg.TLSCertPath + ".tmp"
|
||||
err := os.Remove(tmpCertPath)
|
||||
if err != nil {
|
||||
ltndLog.Warn("Unable to delete temp cert at %v",
|
||||
tmpCertPath)
|
||||
}
|
||||
|
||||
err = t.generateCertPair(keyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certBytes, encryptedKeyBytes, err := cert.GetCertBytesFromPath(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(encryptedKeyBytes)
|
||||
e, err := lnencrypt.KeyRingEncrypter(keyRing)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to generate encrypt key %w",
|
||||
err)
|
||||
}
|
||||
|
||||
keyBytes, err := e.DecryptPayloadFromReader(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Switch the server's TLS certificate to the persistent one. By
|
||||
// changing the cert data the TLSReloader points to,
|
||||
err = t.tlsReloader.AttemptReload(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.deleteEphemeralSettings()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setEphemeralSettings sets the TLSManager settings needed when an ephemeral
|
||||
// certificate is created.
|
||||
func (t *TLSManager) setEphemeralSettings(keyBytes, certBytes []byte,
|
||||
certPath string) {
|
||||
|
||||
t.ephemeralKey = keyBytes
|
||||
t.ephemeralCert = certBytes
|
||||
t.ephemeralCertPath = t.cfg.TLSCertPath + ".tmp"
|
||||
}
|
||||
|
||||
// deleteEphemeralSettings deletes the TLSManager ephemeral settings that are
|
||||
// no longer needed when the ephemeral certificate is deleted so the Manager
|
||||
// knows we're no longer using it.
|
||||
func (t *TLSManager) deleteEphemeralSettings() {
|
||||
t.ephemeralKey = nil
|
||||
t.ephemeralCert = nil
|
||||
t.ephemeralCertPath = ""
|
||||
}
|
||||
|
||||
// fileExists reports whether the named file or directory exists.
|
||||
// This function is taken from https://github.com/btcsuite/btcd
|
||||
func fileExists(name string) bool {
|
||||
if _, err := os.Stat(name); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsCertExpired checks if the current TLS certificate is expired.
|
||||
func (t *TLSManager) IsCertExpired(keyRing keychain.SecretKeyRing) (bool,
|
||||
time.Time, error) {
|
||||
|
||||
certBytes, keyBytes, err := cert.GetCertBytesFromPath(
|
||||
t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
|
||||
)
|
||||
if err != nil {
|
||||
return false, time.Time{}, err
|
||||
}
|
||||
|
||||
// If TLSEncryptKey is set, there are two states the
|
||||
// certificate can be in: ephemeral or permanent.
|
||||
// Retrieve the key depending on which state it is in.
|
||||
if t.ephemeralKey != nil {
|
||||
keyBytes = t.ephemeralKey
|
||||
} else if t.cfg.TLSEncryptKey {
|
||||
keyBytes, err = decryptTLSKeyBytes(keyRing, keyBytes)
|
||||
if err != nil {
|
||||
return false, time.Time{}, err
|
||||
}
|
||||
}
|
||||
|
||||
_, parsedCert, err := cert.LoadCertFromBytes(
|
||||
certBytes, keyBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return false, time.Time{}, err
|
||||
}
|
||||
|
||||
// If the current time is passed the certificate's
|
||||
// expiry time, then it is considered expired
|
||||
if time.Now().After(parsedCert.NotAfter) {
|
||||
return true, parsedCert.NotAfter, nil
|
||||
}
|
||||
|
||||
return false, parsedCert.NotAfter, nil
|
||||
}
|
||||
|
@ -15,90 +15,235 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/lightningnetwork/lnd/cert"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnencrypt"
|
||||
"github.com/lightningnetwork/lnd/lntest/channels"
|
||||
"github.com/lightningnetwork/lnd/lntest/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestTLSAutoRegeneration creates an expired TLS certificate, to test that a
|
||||
const (
|
||||
testTLSCertDuration = 42 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
privKeyBytes = channels.AlicesPrivKey
|
||||
|
||||
privKey, _ = btcec.PrivKeyFromBytes(privKeyBytes)
|
||||
)
|
||||
|
||||
// TestGenerateOrRenewCert creates an expired TLS certificate, to test that a
|
||||
// new TLS certificate pair is regenerated when the old pair expires. This is
|
||||
// necessary because the pair expires after a little over a year.
|
||||
func TestTLSAutoRegeneration(t *testing.T) {
|
||||
func TestGenerateOrRenewCert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDirPath := t.TempDir()
|
||||
|
||||
certPath := tempDirPath + "/tls.cert"
|
||||
keyPath := tempDirPath + "/tls.key"
|
||||
|
||||
certDerBytes, keyBytes := genExpiredCertPair(t, tempDirPath)
|
||||
expiredCert, err := x509.ParseCertificate(certDerBytes)
|
||||
require.NoError(t, err, "failed to parse certificate")
|
||||
|
||||
certBuf := bytes.Buffer{}
|
||||
err = pem.Encode(
|
||||
&certBuf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDerBytes,
|
||||
},
|
||||
// Write an expired certificate to disk.
|
||||
certPath, keyPath, expiredCert := writeTestCertFiles(
|
||||
t, true, false, nil,
|
||||
)
|
||||
require.NoError(t, err, "failed to encode certificate")
|
||||
|
||||
keyBuf := bytes.Buffer{}
|
||||
err = pem.Encode(
|
||||
&keyBuf, &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err, "failed to encode private key")
|
||||
|
||||
// Write cert and key files.
|
||||
err = ioutil.WriteFile(tempDirPath+"/tls.cert", certBuf.Bytes(), 0644)
|
||||
require.NoError(t, err, "failed to write cert file")
|
||||
err = ioutil.WriteFile(tempDirPath+"/tls.key", keyBuf.Bytes(), 0600)
|
||||
require.NoError(t, err, "failed to write key file")
|
||||
|
||||
rpcListener := net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""}
|
||||
rpcListeners := make([]net.Addr, 0)
|
||||
rpcListeners = append(rpcListeners, &rpcListener)
|
||||
|
||||
// Now let's run getTLSConfig. If it works properly, it should delete
|
||||
// the cert and create a new one.
|
||||
cfg := &Config{
|
||||
// Now let's run the TLSManager's getConfig. If it works properly, it
|
||||
// should delete the cert and create a new one.
|
||||
cfg := &TLSManagerCfg{
|
||||
TLSCertPath: certPath,
|
||||
TLSKeyPath: keyPath,
|
||||
TLSCertDuration: 42 * time.Hour,
|
||||
RPCListeners: rpcListeners,
|
||||
TLSCertDuration: testTLSCertDuration,
|
||||
}
|
||||
tlsManager := NewTLSManager(cfg)
|
||||
_, _, _, cleanUp, err := tlsManager.getConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't retrieve TLS config")
|
||||
}
|
||||
_, cleanUp, err := tlsManager.generateOrRenewCert()
|
||||
require.NoError(t, err)
|
||||
_, _, _, err = tlsManager.getConfig()
|
||||
require.NoError(t, err, "couldn't retrieve TLS config")
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
// Grab the certificate to test that getTLSConfig did its job correctly
|
||||
// and generated a new cert.
|
||||
newCertData, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't grab new certificate")
|
||||
}
|
||||
require.NoError(t, err, "couldn't grab new certificate")
|
||||
|
||||
newCert, err := x509.ParseCertificate(newCertData.Certificate[0])
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse new certificate")
|
||||
}
|
||||
require.NoError(t, err, "couldn't parse new certificate")
|
||||
|
||||
// Check that the expired certificate was successfully deleted and
|
||||
// replaced with a new one.
|
||||
if !newCert.NotAfter.After(expiredCert.NotAfter) {
|
||||
t.Fatalf("New certificate expiration is too old")
|
||||
}
|
||||
require.True(t, newCert.NotAfter.After(expiredCert.NotAfter),
|
||||
"New certificate expiration is too old")
|
||||
}
|
||||
|
||||
// genExpiredCertPair generates an expired key/cert pair to test that expired
|
||||
// certificates are being regenerated correctly.
|
||||
func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) {
|
||||
// TestTLSManagerGenCert tests that the new TLS Manager loads correctly,
|
||||
// whether the encrypted TLS key flag is set or not.
|
||||
func TestTLSManagerGenCert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, certPath, keyPath := newTestDirectory(t)
|
||||
|
||||
cfg := &TLSManagerCfg{
|
||||
TLSCertPath: certPath,
|
||||
TLSKeyPath: keyPath,
|
||||
}
|
||||
tlsManager := NewTLSManager(cfg)
|
||||
|
||||
_, _, err := tlsManager.generateOrRenewCert()
|
||||
require.NoError(t, err, "failed to generate new certificate")
|
||||
|
||||
// After this is run, a new certificate should be created and written
|
||||
// to disk. Since the TLSEncryptKey flag isn't set, we should be able
|
||||
// to read it in plaintext from disk.
|
||||
_, keyBytes, err := cert.GetCertBytesFromPath(
|
||||
cfg.TLSCertPath, cfg.TLSKeyPath,
|
||||
)
|
||||
require.NoError(t, err, "unable to load certificate")
|
||||
require.True(t, bytes.HasPrefix(keyBytes, privateKeyPrefix),
|
||||
"key is encrypted, but shouldn't be")
|
||||
|
||||
// Now test that if the TLSEncryptKey flag is set, an encrypted key is
|
||||
// created and written to disk.
|
||||
_, certPath, keyPath = newTestDirectory(t)
|
||||
|
||||
cfg = &TLSManagerCfg{
|
||||
TLSEncryptKey: true,
|
||||
TLSCertPath: certPath,
|
||||
TLSKeyPath: keyPath,
|
||||
TLSCertDuration: testTLSCertDuration,
|
||||
}
|
||||
tlsManager = NewTLSManager(cfg)
|
||||
keyRing := &mock.SecretKeyRing{
|
||||
RootKey: privKey,
|
||||
}
|
||||
|
||||
err = tlsManager.generateCertPair(keyRing)
|
||||
require.NoError(t, err, "failed to generate new certificate")
|
||||
|
||||
_, keyBytes, err = cert.GetCertBytesFromPath(
|
||||
certPath, keyPath,
|
||||
)
|
||||
require.NoError(t, err, "unable to load certificate")
|
||||
require.False(t, bytes.HasPrefix(keyBytes, privateKeyPrefix),
|
||||
"key isn't encrypted, but should be")
|
||||
}
|
||||
|
||||
// TestEnsureEncryption tests that ensureEncryption does a couple of things:
|
||||
// 1) If we have cfg.TLSEncryptKey set, but the tls file saved to disk is not
|
||||
// encrypted, generateOrRenewCert encrypts the file and rewrites it to disk.
|
||||
// 2) If cfg.TLSEncryptKey is not set, but the file *is* encrypted, then we
|
||||
// need to return an error to the user.
|
||||
func TestEnsureEncryption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mock.SecretKeyRing{
|
||||
RootKey: privKey,
|
||||
}
|
||||
|
||||
// Write an unencrypted cert file to disk.
|
||||
certPath, keyPath, _ := writeTestCertFiles(
|
||||
t, false, false, keyRing,
|
||||
)
|
||||
|
||||
cfg := &TLSManagerCfg{
|
||||
TLSEncryptKey: true,
|
||||
TLSCertPath: certPath,
|
||||
TLSKeyPath: keyPath,
|
||||
}
|
||||
tlsManager := NewTLSManager(cfg)
|
||||
|
||||
// Check that the keyBytes are initially plaintext.
|
||||
_, newKeyBytes, err := cert.GetCertBytesFromPath(
|
||||
cfg.TLSCertPath, cfg.TLSKeyPath,
|
||||
)
|
||||
|
||||
require.NoError(t, err, "unable to load certificate files")
|
||||
require.True(t, bytes.HasPrefix(newKeyBytes, privateKeyPrefix),
|
||||
"key doesn't have correct plaintext prefix")
|
||||
|
||||
// ensureEncryption should detect that the TLS key is in plaintext,
|
||||
// encrypt it, and rewrite the encrypted version to disk.
|
||||
err = tlsManager.ensureEncryption(keyRing)
|
||||
require.NoError(t, err, "failed to generate new certificate")
|
||||
|
||||
// Grab the file from disk to check that the key is no longer
|
||||
// plaintext.
|
||||
_, newKeyBytes, err = cert.GetCertBytesFromPath(
|
||||
cfg.TLSCertPath, cfg.TLSKeyPath,
|
||||
)
|
||||
require.NoError(t, err, "unable to load certificate")
|
||||
require.False(t, bytes.HasPrefix(newKeyBytes, privateKeyPrefix),
|
||||
"key isn't encrypted, but should be")
|
||||
|
||||
// Now let's flip the cfg.TLSEncryptKey to false. Since the key on file
|
||||
// is encrypted, ensureEncryption should error out.
|
||||
tlsManager.cfg.TLSEncryptKey = false
|
||||
err = tlsManager.ensureEncryption(keyRing)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestGenerateEphemeralCert tests that an ephemeral certificate is created and
|
||||
// stored to disk in a .tmp file and that LoadPermanentCertificate deletes
|
||||
// file and replaces it with a fresh certificate pair.
|
||||
func TestGenerateEphemeralCert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, certPath, keyPath := newTestDirectory(t)
|
||||
tmpCertPath := certPath + ".tmp"
|
||||
|
||||
cfg := &TLSManagerCfg{
|
||||
TLSCertPath: certPath,
|
||||
TLSKeyPath: keyPath,
|
||||
TLSEncryptKey: true,
|
||||
TLSCertDuration: testTLSCertDuration,
|
||||
}
|
||||
tlsManager := NewTLSManager(cfg)
|
||||
|
||||
keyBytes, err := tlsManager.loadEphemeralCertificate()
|
||||
require.NoError(t, err, "failed to generate new certificate")
|
||||
|
||||
certBytes, err := ioutil.ReadFile(tmpCertPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsr, err := cert.NewTLSReloader(certBytes, keyBytes)
|
||||
require.NoError(t, err)
|
||||
tlsManager.tlsReloader = tlsr
|
||||
|
||||
// Make sure .tmp file is created at the tmp cert path.
|
||||
_, err = ioutil.ReadFile(tmpCertPath)
|
||||
require.NoError(t, err, "couldn't find temp cert file")
|
||||
|
||||
// But no key should be stored.
|
||||
_, err = ioutil.ReadFile(cfg.TLSKeyPath)
|
||||
require.Error(t, err, "shouldn't have found file")
|
||||
|
||||
// And no permanent cert file should be stored.
|
||||
_, err = ioutil.ReadFile(cfg.TLSCertPath)
|
||||
require.Error(t, err, "shouldn't have found a permanent cert file")
|
||||
|
||||
// Now test that when we reload the certificate it generates the new
|
||||
// certificate properly.
|
||||
keyRing := &mock.SecretKeyRing{
|
||||
RootKey: privKey,
|
||||
}
|
||||
err = tlsManager.LoadPermanentCertificate(keyRing)
|
||||
require.NoError(t, err, "unable to reload certificate")
|
||||
|
||||
// Make sure .tmp file is deleted.
|
||||
_, _, err = cert.GetCertBytesFromPath(
|
||||
tmpCertPath, cfg.TLSKeyPath,
|
||||
)
|
||||
require.Error(t, err, ".tmp file should have been deleted")
|
||||
|
||||
// Make sure a certificate now exists at the permanent cert path.
|
||||
_, _, err = cert.GetCertBytesFromPath(
|
||||
cfg.TLSCertPath, cfg.TLSKeyPath,
|
||||
)
|
||||
require.NoError(t, err, "error loading permanent certificate")
|
||||
}
|
||||
|
||||
// genCertPair generates a key/cert pair, with the option of generating expired
|
||||
// certificates to make sure they are being regenerated correctly.
|
||||
func genCertPair(t *testing.T, expired bool) ([]byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
// Max serial number.
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
|
||||
@ -113,6 +258,15 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) {
|
||||
|
||||
dnsNames := []string{host, "unix", "unixpacket"}
|
||||
|
||||
var notBefore, notAfter time.Time
|
||||
if expired {
|
||||
notBefore = time.Now().Add(-time.Hour * 24)
|
||||
notAfter = time.Now()
|
||||
} else {
|
||||
notBefore = time.Now()
|
||||
notAfter = time.Now().Add(time.Hour * 24)
|
||||
}
|
||||
|
||||
// Construct the certificate template.
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
@ -120,16 +274,14 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) {
|
||||
Organization: []string{"lnd autogenerated cert"},
|
||||
CommonName: host,
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour * 24),
|
||||
NotAfter: time.Now(),
|
||||
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment |
|
||||
x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
IsCA: true, // so can sign self.
|
||||
BasicConstraintsValid: true,
|
||||
|
||||
DNSNames: dnsNames,
|
||||
IPAddresses: ipAddresses,
|
||||
DNSNames: dnsNames,
|
||||
IPAddresses: ipAddresses,
|
||||
}
|
||||
|
||||
// Generate a private key for the certificate.
|
||||
@ -148,3 +300,72 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) {
|
||||
|
||||
return certDerBytes, keyBytes
|
||||
}
|
||||
|
||||
// writeTestCertFiles creates test files and writes them to a temporary testing
|
||||
// directory.
|
||||
func writeTestCertFiles(t *testing.T, expiredCert, encryptTLSKey bool,
|
||||
keyRing keychain.KeyRing) (string, string, *x509.Certificate) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
tempDir, certPath, keyPath := newTestDirectory(t)
|
||||
|
||||
var certDerBytes, keyBytes []byte
|
||||
// Either create a valid certificate or an expired certificate pair,
|
||||
// depending on the test.
|
||||
if expiredCert {
|
||||
certDerBytes, keyBytes = genCertPair(t, true)
|
||||
} else {
|
||||
certDerBytes, keyBytes = genCertPair(t, false)
|
||||
}
|
||||
|
||||
parsedCert, err := x509.ParseCertificate(certDerBytes)
|
||||
require.NoError(t, err, "failed to parse certificate")
|
||||
|
||||
certBuf := bytes.Buffer{}
|
||||
err = pem.Encode(
|
||||
&certBuf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDerBytes,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err, "failed to encode certificate")
|
||||
|
||||
var keyBuf *bytes.Buffer
|
||||
if !encryptTLSKey {
|
||||
keyBuf = &bytes.Buffer{}
|
||||
err = pem.Encode(
|
||||
keyBuf, &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err, "failed to encode private key")
|
||||
} else {
|
||||
e, err := lnencrypt.KeyRingEncrypter(keyRing)
|
||||
require.NoError(t, err, "unable to generate key encrypter")
|
||||
err = e.EncryptPayloadToWriter(
|
||||
keyBytes, keyBuf,
|
||||
)
|
||||
require.NoError(t, err, "failed to encrypt private key")
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(tempDir+"/tls.cert", certBuf.Bytes(), 0644)
|
||||
require.NoError(t, err, "failed to write cert file")
|
||||
err = ioutil.WriteFile(tempDir+"/tls.key", keyBuf.Bytes(), 0600)
|
||||
require.NoError(t, err, "failed to write key file")
|
||||
|
||||
return certPath, keyPath, parsedCert
|
||||
}
|
||||
|
||||
// newTestDirectory creates a new test directory and returns the location of
|
||||
// the test tls.cert and tls.key files.
|
||||
func newTestDirectory(t *testing.T) (string, string, string) {
|
||||
t.Helper()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
certPath := tempDir + "/tls.cert"
|
||||
keyPath := tempDir + "/tls.key"
|
||||
|
||||
return tempDir, certPath, keyPath
|
||||
}
|
||||
|
Reference in New Issue
Block a user