diff --git a/config.go b/config.go index 39818397c..af49ce888 100644 --- a/config.go +++ b/config.go @@ -271,6 +271,7 @@ type Config struct { TLSAutoRefresh bool `long:"tlsautorefresh" description:"Re-generate TLS certificate and key if the IPs or domains are changed"` TLSDisableAutofill bool `long:"tlsdisableautofill" description:"Do not include the interface IPs or the system hostname in TLS certificate, use first --tlsextradomain as Common Name instead, if set"` TLSCertDuration time.Duration `long:"tlscertduration" description:"The duration for which the auto-generated TLS certificate will be valid for"` + TLSEncryptKey bool `long:"tlsencryptkey" description:"Automatically encrypts the TLS private key and generates ephemeral TLS key pairs when the wallet is locked or not initialized"` NoMacaroons bool `long:"no-macaroons" description:"Disable macaroon authentication, can only be used if server is not listening on a public interface."` AdminMacPath string `long:"adminmacaroonpath" description:"Path to write the admin macaroon for lnd's RPC and REST services if it doesn't exist"` diff --git a/go.mod b/go.mod index 8428c3c2b..e8d7c8bbe 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/lightninglabs/neutrino v0.14.2 github.com/lightninglabs/protobuf-hex-display v1.4.3-hex-display github.com/lightningnetwork/lightning-onion v1.2.1-0.20221202012345-ca23184850a1 - github.com/lightningnetwork/lnd/cert v1.1.1 + github.com/lightningnetwork/lnd/cert v1.2.0 github.com/lightningnetwork/lnd/clock v1.1.0 github.com/lightningnetwork/lnd/healthcheck v1.2.2 github.com/lightningnetwork/lnd/kvdb v1.3.1 diff --git a/go.sum b/go.sum index 02a68c1fb..2458dcb91 100644 --- a/go.sum +++ b/go.sum @@ -447,6 +447,8 @@ github.com/lightningnetwork/lightning-onion v1.2.1-0.20221202012345-ca23184850a1 github.com/lightningnetwork/lightning-onion v1.2.1-0.20221202012345-ca23184850a1/go.mod h1:7dDx73ApjEZA0kcknI799m2O5kkpfg4/gr7N092ojNo= github.com/lightningnetwork/lnd/cert v1.1.1 h1:Nsav0RlIDRbOnzz2Yu69SQlK939IKya3Q2S0mDviIN8= github.com/lightningnetwork/lnd/cert v1.1.1/go.mod h1:1P46svkkd73oSoeI4zjkVKgZNwGq8bkGuPR8z+5vQUs= +github.com/lightningnetwork/lnd/cert v1.2.0 h1:IWfjHNMI5JgQZU5fdvDptF3DkVI38f4jO/s3tYgWFbE= +github.com/lightningnetwork/lnd/cert v1.2.0/go.mod h1:04JhIEodoR6usBN5+XBRtLEEmEHsclLi0tEyxZQNP+w= github.com/lightningnetwork/lnd/clock v1.0.1/go.mod h1:KnQudQ6w0IAMZi1SgvecLZQZ43ra2vpDNj7H/aasemg= github.com/lightningnetwork/lnd/clock v1.1.0 h1:/yfVAwtPmdx45aQBoXQImeY7sOIEr7IXlImRMBOZ7GQ= github.com/lightningnetwork/lnd/clock v1.1.0/go.mod h1:KnQudQ6w0IAMZi1SgvecLZQZ43ra2vpDNj7H/aasemg= diff --git a/lnd.go b/lnd.go index 277a970fe..8da903e96 100644 --- a/lnd.go +++ b/lnd.go @@ -213,15 +213,31 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, return mkErr("error initializing DBs: %v", err) } - // Only process macaroons if --no-macaroons isn't set. - tlsManager := NewTLSManager(cfg) - serverOpts, restDialOpts, restListen, cleanUp, - err := tlsManager.getConfig() + tlsManagerCfg := &TLSManagerCfg{ + TLSCertPath: cfg.TLSCertPath, + TLSKeyPath: cfg.TLSKeyPath, + TLSEncryptKey: cfg.TLSEncryptKey, + TLSExtraIPs: cfg.TLSExtraIPs, + TLSExtraDomains: cfg.TLSExtraDomains, + TLSAutoRefresh: cfg.TLSAutoRefresh, + TLSDisableAutofill: cfg.TLSDisableAutofill, + TLSCertDuration: cfg.TLSCertDuration, - if err != nil { - return mkErr("unable to load TLS credentials: %v", err) + LetsEncryptDir: cfg.LetsEncryptDir, + LetsEncryptDomain: cfg.LetsEncryptDomain, + LetsEncryptListen: cfg.LetsEncryptListen, + + DisableRestTLS: cfg.DisableRestTLS, + } + tlsManager := NewTLSManager(tlsManagerCfg) + serverOpts, restDialOpts, restListen, cleanUp, + err := tlsManager.SetCertificateBeforeUnlock() + if err != nil { + return mkErr("error setting cert before unlock: %v", err) + } + if cleanUp != nil { + defer cleanUp() } - defer cleanUp() // If we have chosen to start with a dedicated listener for the // rpc server, we set it directly. @@ -512,7 +528,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, server, err := newServer( cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, - multiAcceptor, torController, + multiAcceptor, torController, tlsManager, ) if err != nil { return mkErr("unable to create server: %v", err) @@ -538,6 +554,12 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, } defer atplManager.Stop() + err = tlsManager.LoadPermanentCertificate(activeChainControl.KeyRing) + if err != nil { + return mkErr("unable to load permanent TLS certificate: %v", + err) + } + // Now we have created all dependencies necessary to populate and // start the RPC server. err = rpcServer.addDeps( @@ -629,17 +651,6 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, return nil } -// fileExists reports whether the named file or directory exists. -// This function is taken from https://github.com/btcsuite/btcd -func fileExists(name string) bool { - if _, err := os.Stat(name); err != nil { - if os.IsNotExist(err) { - return false - } - } - return true -} - // bakeMacaroon creates a new macaroon with newest version and the given // permissions then returns it binary serialized. func bakeMacaroon(ctx context.Context, svc *macaroons.Service, diff --git a/server.go b/server.go index 6a6183c1b..8a3d6fedc 100644 --- a/server.go +++ b/server.go @@ -27,7 +27,6 @@ import ( "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/brontide" - "github.com/lightningnetwork/lnd/cert" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/chanbackup" @@ -294,6 +293,8 @@ type server struct { readPool *pool.Read + tlsManager *TLSManager + // featureMgr dispatches feature vectors for various contexts within the // daemon. featureMgr *feature.Manager @@ -473,7 +474,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, chanPredicate chanacceptor.ChannelAcceptor, - torController *tor.Controller) (*server, error) { + torController *tor.Controller, tlsManager *TLSManager) (*server, + error) { var ( err error @@ -600,6 +602,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, customMessageServer: subscribe.NewServer(), + tlsManager: tlsManager, + featureMgr: featureMgr, quit: make(chan struct{}), } @@ -1640,18 +1644,15 @@ func (s *server) createLivenessMonitor(cfg *Config, cc *chainreg.ChainControl) { tlsHealthCheck := healthcheck.NewObservation( "tls", func() error { - _, parsedCert, err := cert.LoadCert( - cfg.TLSCertPath, cfg.TLSKeyPath, + expired, expTime, err := s.tlsManager.IsCertExpired( + s.cc.KeyRing, ) if err != nil { return err } - - // If the current time is passed the certificate's - // expiry time, then it is considered expired - if time.Now().After(parsedCert.NotAfter) { + if expired { return fmt.Errorf("TLS certificate is "+ - "expired as of %v", parsedCert.NotAfter) + "expired as of %v", expTime) } // If the certificate is not outdated, no error needs diff --git a/tls_manager.go b/tls_manager.go index 85845f509..495991b51 100644 --- a/tls_manager.go +++ b/tls_manager.go @@ -1,34 +1,88 @@ package lnd import ( + "bytes" "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" + "io/ioutil" "net" "net/http" "os" "time" "github.com/lightningnetwork/lnd/cert" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnrpc" "golang.org/x/crypto/acme/autocert" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) +const ( + // modifyFilePermissons is the file permission used for writing + // encrypted tls files. + modifyFilePermissions = 0600 + + // validityHours is the number of hours the ephemeral tls certificate + // will be valid, if encrypting tls certificates is turned on. + validityHours = 24 +) + var ( + // privateKeyPrefix is the prefix to a plaintext TLS key. + privateKeyPrefix = []byte("-----BEGIN EC PRIVATE KEY-----") + + // letsEncryptTimeout sets a timeout for the Lets Encrypt server. letsEncryptTimeout = 5 * time.Second ) -// TLSManager generates/renews a TLS certificate when needed and returns the -// certificate configuration options needed for gRPC and REST. +// TLSManagerCfg houses a set of values and methods that is passed to the +// TLSManager for it to properly manage LND's TLS options. +type TLSManagerCfg struct { + TLSCertPath string + TLSKeyPath string + TLSEncryptKey bool + TLSExtraIPs []string + TLSExtraDomains []string + TLSAutoRefresh bool + TLSDisableAutofill bool + TLSCertDuration time.Duration + + LetsEncryptDir string + LetsEncryptDomain string + LetsEncryptListen string + + DisableRestTLS bool +} + +// TLSManager generates/renews a TLS cert/key pair when needed. When required, +// it encrypts the TLS key. It also returns the certificate configuration +// options needed for gRPC and REST. type TLSManager struct { - cfg *Config + cfg *TLSManagerCfg + + // tlsReloader is able to reload the certificate with the + // GetCertificate function. In getConfig, tlsCfg.GetCertificate is + // pointed towards t.tlsReloader.GetCertificateFunc(). When + // TLSReloader's AttemptReload is called, the cert that tlsReloader + // holds is changed, in turn changing the cert data + // tlsCfg.GetCertificate will return. + tlsReloader *cert.TLSReloader + + // These options are only used if we're currently using an ephemeral + // TLS certificate, used when we're encrypting the TLS key. + ephemeralKey []byte + ephemeralCert []byte + ephemeralCertPath string } // NewTLSManager returns a reference to a new TLSManager. -func NewTLSManager(cfg *Config) *TLSManager { +func NewTLSManager(cfg *TLSManagerCfg) *TLSManager { return &TLSManager{ cfg: cfg, } @@ -37,20 +91,54 @@ func NewTLSManager(cfg *Config) *TLSManager { // getConfig returns a TLS configuration for the gRPC server and credentials // and a proxy destination for the REST reverse proxy. func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, - func(net.Addr) (net.Listener, error), func(), error) { + func(net.Addr) (net.Listener, error), error) { - tlsCfg, cleanUp, err := t.generateOrRenewCert() - if err != nil { - return nil, nil, nil, nil, err + var ( + keyBytes, certBytes []byte + err error + ) + if t.ephemeralKey != nil { + keyBytes = t.ephemeralKey + certBytes = t.ephemeralCert + } else { + certBytes, keyBytes, err = cert.GetCertBytesFromPath( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + ) + if err != nil { + return nil, nil, nil, err + } } - // Now that we know that we have a ceritificate, let's generate the + certData, _, err := cert.LoadCertFromBytes(certBytes, keyBytes) + if err != nil { + return nil, nil, nil, err + } + + if t.tlsReloader == nil { + tlsr, err := cert.NewTLSReloader(certBytes, keyBytes) + if err != nil { + return nil, nil, nil, err + } + t.tlsReloader = tlsr + } + + tlsCfg := cert.TLSConfFromCert(certData) + tlsCfg.GetCertificate = t.tlsReloader.GetCertificateFunc() + + // If we're using the ephemeral certificate, we need to use the + // ephemeral cert path. + certPath := t.cfg.TLSCertPath + if t.ephemeralCertPath != "" { + certPath = t.ephemeralCertPath + } + + // Now that we know that we have a certificate, let's generate the // required config options. restCreds, err := credentials.NewClientTLSFromFile( - t.cfg.TLSCertPath, "", + certPath, "", ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } serverCreds := credentials.NewTLS(tlsCfg) @@ -80,14 +168,15 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, return lncfg.TLSListenOnAddress(addr, tlsCfg) } - return serverOpts, restDialOpts, restListen, cleanUp, nil + return serverOpts, restDialOpts, restListen, nil } // generateOrRenewCert generates a new TLS certificate if we're not using one // yet or renews it if it's outdated. func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) { - // Genereate a TLS pair if we don't have one yet. - err := t.createTLSPair() + // Generete a TLS pair if we don't have one yet. + var emptyKeyRing keychain.SecretKeyRing + err := t.generateCertPair(emptyKeyRing) if err != nil { return nil, nil, err } @@ -101,7 +190,7 @@ func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) { // Check to see if the certificate needs to be renewed. If it does, we // return the newly generated certificate data instead. - reloadedCertData, err := t.certMaintenance(parsedCert) + reloadedCertData, err := t.maintainCert(parsedCert) if err != nil { return nil, nil, err } @@ -110,38 +199,134 @@ func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) { } tlsCfg := cert.TLSConfFromCert(certData) - cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg) return tlsCfg, cleanUp, nil } -// createTLSPair creates and writes a TLS pair to disk if the pair doesn't -// exist yet. If the TLSEncryptKey setting is on, and a plaintext key is -// already written to disk, this function overwrites the plaintext key with +// generateCertPair creates and writes a TLS pair to disk if the pair +// doesn't exist yet. If the TLSEncryptKey setting is on, and a plaintext key +// is already written to disk, this function overwrites the plaintext key with // the encrypted form. -func (t *TLSManager) createTLSPair() error { +func (t *TLSManager) generateCertPair(keyRing keychain.SecretKeyRing) error { // Ensure we create TLS key and certificate if they don't exist. if fileExists(t.cfg.TLSCertPath) || fileExists(t.cfg.TLSKeyPath) { - return nil + // Handle discrepencies related to the TLSEncryptKey setting. + return t.ensureEncryption(keyRing) } rpcsLog.Infof("Generating TLS certificates...") - err := cert.GenCertPair( + certBytes, keyBytes, err := cert.GenCertPair( "lnd autogenerated cert", t.cfg.TLSCertPath, t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, t.cfg.TLSCertDuration, ) + if err != nil { + return err + } + + if t.cfg.TLSEncryptKey { + var b bytes.Buffer + e, err := lnencrypt.KeyRingEncrypter(keyRing) + if err != nil { + return fmt.Errorf("unable to create "+ + "encrypt key %v", err) + } + + err = e.EncryptPayloadToWriter( + keyBytes, &b, + ) + if err != nil { + return err + } + + keyBytes = b.Bytes() + } + + err = cert.WriteCertPair( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, certBytes, keyBytes, + ) + rpcsLog.Infof("Done generating TLS certificates") return err } -// certMaintenance checks if the certificate IP and domains matches the config, +// ensureEncryption takes a look at a couple of things: +// 1) If the TLS key is in plaintext, but TLSEncryptKey is set, we need to +// encrypt the file and rewrite it to disk. +// 2) On the flip side, if TLSEncryptKey is not set, but the key on disk +// is encrypted, we need to error out and warn the user. +func (t *TLSManager) ensureEncryption(keyRing keychain.SecretKeyRing) error { + _, keyBytes, err := cert.GetCertBytesFromPath( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + ) + if err != nil { + return err + } + + if t.cfg.TLSEncryptKey && bytes.HasPrefix(keyBytes, privateKeyPrefix) { + var b bytes.Buffer + e, err := lnencrypt.KeyRingEncrypter(keyRing) + if err != nil { + return fmt.Errorf("unable to generate encrypt key %w", + err) + } + + err = e.EncryptPayloadToWriter(keyBytes, &b) + if err != nil { + return err + } + err = ioutil.WriteFile( + t.cfg.TLSKeyPath, b.Bytes(), modifyFilePermissions, + ) + if err != nil { + return err + } + } + + // If the private key is encrypted but the user didn't pass + // --tlsencryptkey we error out. This is because the wallet is not + // unlocked yet and we don't have access to the keys yet for decryption. + if !t.cfg.TLSEncryptKey && !bytes.HasPrefix(keyBytes, + privateKeyPrefix) { + + ltndLog.Errorf("The TLS private key is encrypted on disk.") + + return errors.New("the TLS key is encrypted but the " + + "--tlsencryptkey flag is not passed. Please either " + + "restart lnd with the --tlsencryptkey flag or delete " + + "the TLS files for regeneration") + } + + return nil +} + +// decryptTLSKeyBytes decrypts the TLS key. +func decryptTLSKeyBytes(keyRing keychain.SecretKeyRing, + encryptedData []byte) ([]byte, error) { + + reader := bytes.NewReader(encryptedData) + encrypter, err := lnencrypt.KeyRingEncrypter(keyRing) + if err != nil { + return nil, err + } + + plaintext, err := encrypter.DecryptPayloadFromReader( + reader, + ) + if err != nil { + return nil, err + } + + return plaintext, nil +} + +// maintainCert checks if the certificate IP and domains matches the config, // and renews the certificate if either this data is outdated or the // certificate is expired. -func (t *TLSManager) certMaintenance( +func (t *TLSManager) maintainCert( parsedCert *x509.Certificate) (*tls.Certificate, error) { // We check whether the certificate we have on disk match the IPs and @@ -180,26 +365,30 @@ func (t *TLSManager) certMaintenance( } rpcsLog.Infof("Renewing TLS certificates...") - err = cert.GenCertPair( - "lnd autogenerated cert", t.cfg.TLSCertPath, - t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs, - t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, - t.cfg.TLSCertDuration, + certBytes, keyBytes, err := cert.GenCertPair( + "lnd autogenerated cert", t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains, + t.cfg.TLSDisableAutofill, t.cfg.TLSCertDuration, ) if err != nil { return nil, err } + + err = cert.WriteCertPair( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, certBytes, keyBytes, + ) + if err != nil { + return nil, err + } + rpcsLog.Infof("Done renewing TLS certificates") // Reload the certificate data. reloadedCertData, _, err := cert.LoadCert( t.cfg.TLSCertPath, t.cfg.TLSKeyPath, ) - if err != nil { - return nil, err - } - return &reloadedCertData, nil + return &reloadedCertData, err } // setUpLetsEncrypt automatically generates a Let's Encrypt certificate if the @@ -271,3 +460,190 @@ func (t *TLSManager) setUpLetsEncrypt(certData *tls.Certificate, return cleanUp } + +// SetCertificateBeforeUnlock takes care of loading the certificate before +// the wallet is unlocked. If the TLSEncryptKey setting is on, we need to +// generate an ephemeral certificate we're able to use until the wallet is +// unlocked and a new TLS pair can be encrypted to disk. Otherwise we can +// process the certificate normally. +func (t *TLSManager) SetCertificateBeforeUnlock() ([]grpc.ServerOption, + []grpc.DialOption, func(net.Addr) (net.Listener, error), func(), + error) { + + var cleanUp func() + if t.cfg.TLSEncryptKey { + _, err := t.loadEphemeralCertificate() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("unable to load "+ + "ephemeral certificate: %v", err) + } + } else { + _, cleanUpFunc, err := t.generateOrRenewCert() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("unable to "+ + "generate or renew TLS certificate: %v", err) + } + cleanUp = cleanUpFunc + } + + serverOpts, restDialOpts, restListen, err := t.getConfig() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("unable to load TLS "+ + "credentials: %v", err) + } + + return serverOpts, restDialOpts, restListen, cleanUp, nil +} + +// loadEphemeralCertificate creates and loads the ephemeral certificate which +// is used temporarily for secure communications before the wallet is unlocked. +func (t *TLSManager) loadEphemeralCertificate() ([]byte, error) { + rpcsLog.Infof("Generating ephemeral TLS certificates...") + + tmpValidity := validityHours * time.Hour + // Append .tmp to the end of the cert for differentiation. + tmpCertPath := t.cfg.TLSCertPath + ".tmp" + + // Pass in a blank string for the key path so the + // function doesn't write them to disk. + certBytes, keyBytes, err := cert.GenCertPair( + "lnd ephemeral autogenerated cert", tmpCertPath, + "", t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains, + t.cfg.TLSDisableAutofill, tmpValidity, + ) + if err != nil { + return nil, err + } + t.setEphemeralSettings(keyBytes, certBytes, t.cfg.TLSCertPath+".tmp") + + err = cert.WriteCertPair(tmpCertPath, "", certBytes, keyBytes) + if err != nil { + return nil, err + } + + rpcsLog.Infof("Done generating ephemeral TLS certificates") + + return keyBytes, nil +} + +// LoadPermanentCertificate deletes the ephemeral certificate file and +// generates a new one with the real keyring. +func (t *TLSManager) LoadPermanentCertificate( + keyRing keychain.SecretKeyRing) error { + + if !t.cfg.TLSEncryptKey { + return nil + } + + tmpCertPath := t.cfg.TLSCertPath + ".tmp" + err := os.Remove(tmpCertPath) + if err != nil { + ltndLog.Warn("Unable to delete temp cert at %v", + tmpCertPath) + } + + err = t.generateCertPair(keyRing) + if err != nil { + return err + } + + certBytes, encryptedKeyBytes, err := cert.GetCertBytesFromPath( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + ) + if err != nil { + return err + } + + reader := bytes.NewReader(encryptedKeyBytes) + e, err := lnencrypt.KeyRingEncrypter(keyRing) + if err != nil { + return fmt.Errorf("unable to generate encrypt key %w", + err) + } + + keyBytes, err := e.DecryptPayloadFromReader(reader) + if err != nil { + return err + } + + // Switch the server's TLS certificate to the persistent one. By + // changing the cert data the TLSReloader points to, + err = t.tlsReloader.AttemptReload(certBytes, keyBytes) + if err != nil { + return err + } + + t.deleteEphemeralSettings() + + return nil +} + +// setEphemeralSettings sets the TLSManager settings needed when an ephemeral +// certificate is created. +func (t *TLSManager) setEphemeralSettings(keyBytes, certBytes []byte, + certPath string) { + + t.ephemeralKey = keyBytes + t.ephemeralCert = certBytes + t.ephemeralCertPath = t.cfg.TLSCertPath + ".tmp" +} + +// deleteEphemeralSettings deletes the TLSManager ephemeral settings that are +// no longer needed when the ephemeral certificate is deleted so the Manager +// knows we're no longer using it. +func (t *TLSManager) deleteEphemeralSettings() { + t.ephemeralKey = nil + t.ephemeralCert = nil + t.ephemeralCertPath = "" +} + +// fileExists reports whether the named file or directory exists. +// This function is taken from https://github.com/btcsuite/btcd +func fileExists(name string) bool { + if _, err := os.Stat(name); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} + +// IsCertExpired checks if the current TLS certificate is expired. +func (t *TLSManager) IsCertExpired(keyRing keychain.SecretKeyRing) (bool, + time.Time, error) { + + certBytes, keyBytes, err := cert.GetCertBytesFromPath( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + ) + if err != nil { + return false, time.Time{}, err + } + + // If TLSEncryptKey is set, there are two states the + // certificate can be in: ephemeral or permanent. + // Retrieve the key depending on which state it is in. + if t.ephemeralKey != nil { + keyBytes = t.ephemeralKey + } else if t.cfg.TLSEncryptKey { + keyBytes, err = decryptTLSKeyBytes(keyRing, keyBytes) + if err != nil { + return false, time.Time{}, err + } + } + + _, parsedCert, err := cert.LoadCertFromBytes( + certBytes, keyBytes, + ) + if err != nil { + return false, time.Time{}, err + } + + // If the current time is passed the certificate's + // expiry time, then it is considered expired + if time.Now().After(parsedCert.NotAfter) { + return true, parsedCert.NotAfter, nil + } + + return false, parsedCert.NotAfter, nil +} diff --git a/tls_manager_test.go b/tls_manager_test.go index e2d9cdc31..fcd2eadbf 100644 --- a/tls_manager_test.go +++ b/tls_manager_test.go @@ -15,90 +15,235 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/cert" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnencrypt" + "github.com/lightningnetwork/lnd/lntest/channels" + "github.com/lightningnetwork/lnd/lntest/mock" "github.com/stretchr/testify/require" ) -// TestTLSAutoRegeneration creates an expired TLS certificate, to test that a +const ( + testTLSCertDuration = 42 * time.Hour +) + +var ( + privKeyBytes = channels.AlicesPrivKey + + privKey, _ = btcec.PrivKeyFromBytes(privKeyBytes) +) + +// TestGenerateOrRenewCert creates an expired TLS certificate, to test that a // new TLS certificate pair is regenerated when the old pair expires. This is // necessary because the pair expires after a little over a year. -func TestTLSAutoRegeneration(t *testing.T) { +func TestGenerateOrRenewCert(t *testing.T) { t.Parallel() - tempDirPath := t.TempDir() - - certPath := tempDirPath + "/tls.cert" - keyPath := tempDirPath + "/tls.key" - - certDerBytes, keyBytes := genExpiredCertPair(t, tempDirPath) - expiredCert, err := x509.ParseCertificate(certDerBytes) - require.NoError(t, err, "failed to parse certificate") - - certBuf := bytes.Buffer{} - err = pem.Encode( - &certBuf, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certDerBytes, - }, + // Write an expired certificate to disk. + certPath, keyPath, expiredCert := writeTestCertFiles( + t, true, false, nil, ) - require.NoError(t, err, "failed to encode certificate") - keyBuf := bytes.Buffer{} - err = pem.Encode( - &keyBuf, &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: keyBytes, - }, - ) - require.NoError(t, err, "failed to encode private key") - - // Write cert and key files. - err = ioutil.WriteFile(tempDirPath+"/tls.cert", certBuf.Bytes(), 0644) - require.NoError(t, err, "failed to write cert file") - err = ioutil.WriteFile(tempDirPath+"/tls.key", keyBuf.Bytes(), 0600) - require.NoError(t, err, "failed to write key file") - - rpcListener := net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""} - rpcListeners := make([]net.Addr, 0) - rpcListeners = append(rpcListeners, &rpcListener) - - // Now let's run getTLSConfig. If it works properly, it should delete - // the cert and create a new one. - cfg := &Config{ + // Now let's run the TLSManager's getConfig. If it works properly, it + // should delete the cert and create a new one. + cfg := &TLSManagerCfg{ TLSCertPath: certPath, TLSKeyPath: keyPath, - TLSCertDuration: 42 * time.Hour, - RPCListeners: rpcListeners, + TLSCertDuration: testTLSCertDuration, } tlsManager := NewTLSManager(cfg) - _, _, _, cleanUp, err := tlsManager.getConfig() - if err != nil { - t.Fatalf("couldn't retrieve TLS config") - } + _, cleanUp, err := tlsManager.generateOrRenewCert() + require.NoError(t, err) + _, _, _, err = tlsManager.getConfig() + require.NoError(t, err, "couldn't retrieve TLS config") t.Cleanup(cleanUp) // Grab the certificate to test that getTLSConfig did its job correctly // and generated a new cert. newCertData, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - t.Fatalf("couldn't grab new certificate") - } + require.NoError(t, err, "couldn't grab new certificate") newCert, err := x509.ParseCertificate(newCertData.Certificate[0]) - if err != nil { - t.Fatalf("couldn't parse new certificate") - } + require.NoError(t, err, "couldn't parse new certificate") // Check that the expired certificate was successfully deleted and // replaced with a new one. - if !newCert.NotAfter.After(expiredCert.NotAfter) { - t.Fatalf("New certificate expiration is too old") - } + require.True(t, newCert.NotAfter.After(expiredCert.NotAfter), + "New certificate expiration is too old") } -// genExpiredCertPair generates an expired key/cert pair to test that expired -// certificates are being regenerated correctly. -func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) { +// TestTLSManagerGenCert tests that the new TLS Manager loads correctly, +// whether the encrypted TLS key flag is set or not. +func TestTLSManagerGenCert(t *testing.T) { + t.Parallel() + + _, certPath, keyPath := newTestDirectory(t) + + cfg := &TLSManagerCfg{ + TLSCertPath: certPath, + TLSKeyPath: keyPath, + } + tlsManager := NewTLSManager(cfg) + + _, _, err := tlsManager.generateOrRenewCert() + require.NoError(t, err, "failed to generate new certificate") + + // After this is run, a new certificate should be created and written + // to disk. Since the TLSEncryptKey flag isn't set, we should be able + // to read it in plaintext from disk. + _, keyBytes, err := cert.GetCertBytesFromPath( + cfg.TLSCertPath, cfg.TLSKeyPath, + ) + require.NoError(t, err, "unable to load certificate") + require.True(t, bytes.HasPrefix(keyBytes, privateKeyPrefix), + "key is encrypted, but shouldn't be") + + // Now test that if the TLSEncryptKey flag is set, an encrypted key is + // created and written to disk. + _, certPath, keyPath = newTestDirectory(t) + + cfg = &TLSManagerCfg{ + TLSEncryptKey: true, + TLSCertPath: certPath, + TLSKeyPath: keyPath, + TLSCertDuration: testTLSCertDuration, + } + tlsManager = NewTLSManager(cfg) + keyRing := &mock.SecretKeyRing{ + RootKey: privKey, + } + + err = tlsManager.generateCertPair(keyRing) + require.NoError(t, err, "failed to generate new certificate") + + _, keyBytes, err = cert.GetCertBytesFromPath( + certPath, keyPath, + ) + require.NoError(t, err, "unable to load certificate") + require.False(t, bytes.HasPrefix(keyBytes, privateKeyPrefix), + "key isn't encrypted, but should be") +} + +// TestEnsureEncryption tests that ensureEncryption does a couple of things: +// 1) If we have cfg.TLSEncryptKey set, but the tls file saved to disk is not +// encrypted, generateOrRenewCert encrypts the file and rewrites it to disk. +// 2) If cfg.TLSEncryptKey is not set, but the file *is* encrypted, then we +// need to return an error to the user. +func TestEnsureEncryption(t *testing.T) { + t.Parallel() + + keyRing := &mock.SecretKeyRing{ + RootKey: privKey, + } + + // Write an unencrypted cert file to disk. + certPath, keyPath, _ := writeTestCertFiles( + t, false, false, keyRing, + ) + + cfg := &TLSManagerCfg{ + TLSEncryptKey: true, + TLSCertPath: certPath, + TLSKeyPath: keyPath, + } + tlsManager := NewTLSManager(cfg) + + // Check that the keyBytes are initially plaintext. + _, newKeyBytes, err := cert.GetCertBytesFromPath( + cfg.TLSCertPath, cfg.TLSKeyPath, + ) + + require.NoError(t, err, "unable to load certificate files") + require.True(t, bytes.HasPrefix(newKeyBytes, privateKeyPrefix), + "key doesn't have correct plaintext prefix") + + // ensureEncryption should detect that the TLS key is in plaintext, + // encrypt it, and rewrite the encrypted version to disk. + err = tlsManager.ensureEncryption(keyRing) + require.NoError(t, err, "failed to generate new certificate") + + // Grab the file from disk to check that the key is no longer + // plaintext. + _, newKeyBytes, err = cert.GetCertBytesFromPath( + cfg.TLSCertPath, cfg.TLSKeyPath, + ) + require.NoError(t, err, "unable to load certificate") + require.False(t, bytes.HasPrefix(newKeyBytes, privateKeyPrefix), + "key isn't encrypted, but should be") + + // Now let's flip the cfg.TLSEncryptKey to false. Since the key on file + // is encrypted, ensureEncryption should error out. + tlsManager.cfg.TLSEncryptKey = false + err = tlsManager.ensureEncryption(keyRing) + require.Error(t, err) +} + +// TestGenerateEphemeralCert tests that an ephemeral certificate is created and +// stored to disk in a .tmp file and that LoadPermanentCertificate deletes +// file and replaces it with a fresh certificate pair. +func TestGenerateEphemeralCert(t *testing.T) { + t.Parallel() + + _, certPath, keyPath := newTestDirectory(t) + tmpCertPath := certPath + ".tmp" + + cfg := &TLSManagerCfg{ + TLSCertPath: certPath, + TLSKeyPath: keyPath, + TLSEncryptKey: true, + TLSCertDuration: testTLSCertDuration, + } + tlsManager := NewTLSManager(cfg) + + keyBytes, err := tlsManager.loadEphemeralCertificate() + require.NoError(t, err, "failed to generate new certificate") + + certBytes, err := ioutil.ReadFile(tmpCertPath) + require.NoError(t, err) + + tlsr, err := cert.NewTLSReloader(certBytes, keyBytes) + require.NoError(t, err) + tlsManager.tlsReloader = tlsr + + // Make sure .tmp file is created at the tmp cert path. + _, err = ioutil.ReadFile(tmpCertPath) + require.NoError(t, err, "couldn't find temp cert file") + + // But no key should be stored. + _, err = ioutil.ReadFile(cfg.TLSKeyPath) + require.Error(t, err, "shouldn't have found file") + + // And no permanent cert file should be stored. + _, err = ioutil.ReadFile(cfg.TLSCertPath) + require.Error(t, err, "shouldn't have found a permanent cert file") + + // Now test that when we reload the certificate it generates the new + // certificate properly. + keyRing := &mock.SecretKeyRing{ + RootKey: privKey, + } + err = tlsManager.LoadPermanentCertificate(keyRing) + require.NoError(t, err, "unable to reload certificate") + + // Make sure .tmp file is deleted. + _, _, err = cert.GetCertBytesFromPath( + tmpCertPath, cfg.TLSKeyPath, + ) + require.Error(t, err, ".tmp file should have been deleted") + + // Make sure a certificate now exists at the permanent cert path. + _, _, err = cert.GetCertBytesFromPath( + cfg.TLSCertPath, cfg.TLSKeyPath, + ) + require.NoError(t, err, "error loading permanent certificate") +} + +// genCertPair generates a key/cert pair, with the option of generating expired +// certificates to make sure they are being regenerated correctly. +func genCertPair(t *testing.T, expired bool) ([]byte, []byte) { t.Helper() + // Max serial number. serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) @@ -113,6 +258,15 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) { dnsNames := []string{host, "unix", "unixpacket"} + var notBefore, notAfter time.Time + if expired { + notBefore = time.Now().Add(-time.Hour * 24) + notAfter = time.Now() + } else { + notBefore = time.Now() + notAfter = time.Now().Add(time.Hour * 24) + } + // Construct the certificate template. template := x509.Certificate{ SerialNumber: serialNumber, @@ -120,16 +274,14 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) { Organization: []string{"lnd autogenerated cert"}, CommonName: host, }, - NotBefore: time.Now().Add(-time.Hour * 24), - NotAfter: time.Now(), - + NotBefore: notBefore, + NotAfter: notAfter, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, IsCA: true, // so can sign self. BasicConstraintsValid: true, - - DNSNames: dnsNames, - IPAddresses: ipAddresses, + DNSNames: dnsNames, + IPAddresses: ipAddresses, } // Generate a private key for the certificate. @@ -148,3 +300,72 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) { return certDerBytes, keyBytes } + +// writeTestCertFiles creates test files and writes them to a temporary testing +// directory. +func writeTestCertFiles(t *testing.T, expiredCert, encryptTLSKey bool, + keyRing keychain.KeyRing) (string, string, *x509.Certificate) { + + t.Helper() + + tempDir, certPath, keyPath := newTestDirectory(t) + + var certDerBytes, keyBytes []byte + // Either create a valid certificate or an expired certificate pair, + // depending on the test. + if expiredCert { + certDerBytes, keyBytes = genCertPair(t, true) + } else { + certDerBytes, keyBytes = genCertPair(t, false) + } + + parsedCert, err := x509.ParseCertificate(certDerBytes) + require.NoError(t, err, "failed to parse certificate") + + certBuf := bytes.Buffer{} + err = pem.Encode( + &certBuf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certDerBytes, + }, + ) + require.NoError(t, err, "failed to encode certificate") + + var keyBuf *bytes.Buffer + if !encryptTLSKey { + keyBuf = &bytes.Buffer{} + err = pem.Encode( + keyBuf, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + }, + ) + require.NoError(t, err, "failed to encode private key") + } else { + e, err := lnencrypt.KeyRingEncrypter(keyRing) + require.NoError(t, err, "unable to generate key encrypter") + err = e.EncryptPayloadToWriter( + keyBytes, keyBuf, + ) + require.NoError(t, err, "failed to encrypt private key") + } + + err = ioutil.WriteFile(tempDir+"/tls.cert", certBuf.Bytes(), 0644) + require.NoError(t, err, "failed to write cert file") + err = ioutil.WriteFile(tempDir+"/tls.key", keyBuf.Bytes(), 0600) + require.NoError(t, err, "failed to write key file") + + return certPath, keyPath, parsedCert +} + +// newTestDirectory creates a new test directory and returns the location of +// the test tls.cert and tls.key files. +func newTestDirectory(t *testing.T) (string, string, string) { + t.Helper() + + tempDir := t.TempDir() + certPath := tempDir + "/tls.cert" + keyPath := tempDir + "/tls.key" + + return tempDir, certPath, keyPath +}