From 7e03297b1b5bab19550e0414700d48059d7394e7 Mon Sep 17 00:00:00 2001 From: Orbital Date: Mon, 7 Feb 2022 18:05:29 -0600 Subject: [PATCH] lnd: refactor getTLSConfig --- lnd.go | 347 ++++++++++++++++++++++++++++++------------------- server_test.go | 3 +- 2 files changed, 215 insertions(+), 135 deletions(-) diff --git a/lnd.go b/lnd.go index 11c7e539f..8746de9d5 100644 --- a/lnd.go +++ b/lnd.go @@ -7,6 +7,7 @@ package lnd import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "io/ioutil" @@ -217,11 +218,13 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, } // Only process macaroons if --no-macaroons isn't set. - serverOpts, restDialOpts, restListen, cleanUp, err := getTLSConfig(cfg) + tlsManager := NewTLSManager(cfg) + serverOpts, restDialOpts, restListen, cleanUp, + err := tlsManager.getConfig() + if err != nil { return mkErr("unable to load TLS credentials: %v", err) } - defer cleanUp() // If we have chosen to start with a dedicated listener for the @@ -630,147 +633,38 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, return nil } -// getTLSConfig returns a TLS configuration for the gRPC server and credentials +// TLSManager generates/renews a TLS certificate when needed and returns the +// certificate configuration options needed for gRPC and REST. +type TLSManager struct { + cfg *Config +} + +// NewTLSManager returns a reference to a new TLSManager. +func NewTLSManager(cfg *Config) *TLSManager { + return &TLSManager{ + cfg: cfg, + } +} + +// getConfig returns a TLS configuration for the gRPC server and credentials // and a proxy destination for the REST reverse proxy. -func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, +func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, func(net.Addr) (net.Listener, error), func(), error) { - // Ensure we create TLS key and certificate if they don't exist. - if !fileExists(cfg.TLSCertPath) && !fileExists(cfg.TLSKeyPath) { - rpcsLog.Infof("Generating TLS certificates...") - err := cert.GenCertPair( - "lnd autogenerated cert", cfg.TLSCertPath, - cfg.TLSKeyPath, cfg.TLSExtraIPs, cfg.TLSExtraDomains, - cfg.TLSDisableAutofill, cfg.TLSCertDuration, - ) - if err != nil { - return nil, nil, nil, nil, err - } - rpcsLog.Infof("Done generating TLS certificates") + tlsCfg, cleanUp, err := t.generateOrRenewCert() + if err != nil { + return nil, nil, nil, nil, err } - certData, parsedCert, err := cert.LoadCert( - cfg.TLSCertPath, cfg.TLSKeyPath, + // Now that we know that we have a ceritificate, let's generate the + // required config options. + restCreds, err := credentials.NewClientTLSFromFile( + t.cfg.TLSCertPath, "", ) if err != nil { return nil, nil, nil, nil, err } - // We check whether the certificate we have on disk match the IPs and - // domains specified by the config. If the extra IPs or domains have - // changed from when the certificate was created, we will refresh the - // certificate if auto refresh is active. - refresh := false - if cfg.TLSAutoRefresh { - refresh, err = cert.IsOutdated( - parsedCert, cfg.TLSExtraIPs, - cfg.TLSExtraDomains, cfg.TLSDisableAutofill, - ) - if err != nil { - return nil, nil, nil, nil, err - } - } - - // If the certificate expired or it was outdated, delete it and the TLS - // key and generate a new pair. - if time.Now().After(parsedCert.NotAfter) || refresh { - ltndLog.Info("TLS certificate is expired or outdated, " + - "generating a new one") - - err := os.Remove(cfg.TLSCertPath) - if err != nil { - return nil, nil, nil, nil, err - } - - err = os.Remove(cfg.TLSKeyPath) - if err != nil { - return nil, nil, nil, nil, err - } - - rpcsLog.Infof("Renewing TLS certificates...") - err = cert.GenCertPair( - "lnd autogenerated cert", cfg.TLSCertPath, - cfg.TLSKeyPath, cfg.TLSExtraIPs, cfg.TLSExtraDomains, - cfg.TLSDisableAutofill, cfg.TLSCertDuration, - ) - if err != nil { - return nil, nil, nil, nil, err - } - rpcsLog.Infof("Done renewing TLS certificates") - - // Reload the certificate data. - certData, _, err = cert.LoadCert( - cfg.TLSCertPath, cfg.TLSKeyPath, - ) - if err != nil { - return nil, nil, nil, nil, err - } - } - - tlsCfg := cert.TLSConfFromCert(certData) - - restCreds, err := credentials.NewClientTLSFromFile(cfg.TLSCertPath, "") - if err != nil { - return nil, nil, nil, nil, err - } - - // If Let's Encrypt is enabled, instantiate autocert to request/renew - // the certificates. - cleanUp := func() {} - if cfg.LetsEncryptDomain != "" { - ltndLog.Infof("Using Let's Encrypt certificate for domain %v", - cfg.LetsEncryptDomain) - - manager := autocert.Manager{ - Cache: autocert.DirCache(cfg.LetsEncryptDir), - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(cfg.LetsEncryptDomain), - } - - srv := &http.Server{ - Addr: cfg.LetsEncryptListen, - Handler: manager.HTTPHandler(nil), - } - shutdownCompleted := make(chan struct{}) - cleanUp = func() { - err := srv.Shutdown(context.Background()) - if err != nil { - ltndLog.Errorf("Autocert listener shutdown "+ - " error: %v", err) - - return - } - <-shutdownCompleted - ltndLog.Infof("Autocert challenge listener stopped") - } - - go func() { - ltndLog.Infof("Autocert challenge listener started "+ - "at %v", cfg.LetsEncryptListen) - - err := srv.ListenAndServe() - if err != http.ErrServerClosed { - ltndLog.Errorf("autocert http: %v", err) - } - close(shutdownCompleted) - }() - - getCertificate := func(h *tls.ClientHelloInfo) ( - *tls.Certificate, error) { - - lecert, err := manager.GetCertificate(h) - if err != nil { - ltndLog.Errorf("GetCertificate: %v", err) - return &certData, nil - } - - return lecert, err - } - - // The self-signed tls.cert remains available as fallback. - tlsCfg.GetCertificate = getCertificate - } - serverCreds := credentials.NewTLS(tlsCfg) serverOpts := []grpc.ServerOption{grpc.Creds(serverCreds)} @@ -791,7 +685,7 @@ func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, restListen := func(addr net.Addr) (net.Listener, error) { // For restListen we will call ListenOnAddress if TLS is // disabled. - if cfg.DisableRestTLS { + if t.cfg.DisableRestTLS { return lncfg.ListenOnAddress(addr) } @@ -801,6 +695,191 @@ func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, return serverOpts, restDialOpts, restListen, cleanUp, nil } +// generateOrRenewCert generates a new TLS certificate if we're not using one +// yet or renews it if it's outdated. +func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) { + // Genereate a TLS pair if we don't have one yet. + err := t.createTLSPair() + if err != nil { + return nil, nil, err + } + + certData, parsedCert, err := cert.LoadCert( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + ) + if err != nil { + return nil, nil, err + } + + // Check to see if the certificate needs to be renewed. If it does, we + // return the newly generated certificate data instead. + reloadedCertData, err := t.certMaintenance(parsedCert) + if err != nil { + return nil, nil, err + } + if reloadedCertData != nil { + certData = *reloadedCertData + } + + tlsCfg := cert.TLSConfFromCert(certData) + + cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg) + + return tlsCfg, cleanUp, nil +} + +// If a TLS pair doesn't exist yet, create them and write them to disk. +func (t *TLSManager) createTLSPair() error { + // Ensure we create TLS key and certificate if they don't exist. + if fileExists(t.cfg.TLSCertPath) || fileExists(t.cfg.TLSKeyPath) { + return nil + } + + rpcsLog.Infof("Generating TLS certificates...") + err := cert.GenCertPair( + "lnd autogenerated cert", t.cfg.TLSCertPath, + t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs, + t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, + t.cfg.TLSCertDuration, + ) + rpcsLog.Infof("Done generating TLS certificates") + + return err +} + +// certMaintenance checks if the certificate IP and domains matches the config, +// and renews the certificate if either this data is outdated or the +// certificate is expired. +func (t *TLSManager) certMaintenance( + parsedCert *x509.Certificate) (*tls.Certificate, error) { + + // We check whether the certificate we have on disk match the IPs and + // domains specified by the config. If the extra IPs or domains have + // changed from when the certificate was created, we will refresh the + // certificate if auto refresh is active. + refresh := false + var err error + if t.cfg.TLSAutoRefresh { + refresh, err = cert.IsOutdated( + parsedCert, t.cfg.TLSExtraIPs, + t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, + ) + if err != nil { + return nil, err + } + } + + // If the certificate expired or it was outdated, delete it and the TLS + // key and generate a new pair. + if !time.Now().After(parsedCert.NotAfter) && !refresh { + return nil, nil + } + + ltndLog.Info("TLS certificate is expired or outdated, " + + "generating a new one") + + err = os.Remove(t.cfg.TLSCertPath) + if err != nil { + return nil, err + } + + err = os.Remove(t.cfg.TLSKeyPath) + if err != nil { + return nil, err + } + + rpcsLog.Infof("Renewing TLS certificates...") + err = cert.GenCertPair( + "lnd autogenerated cert", t.cfg.TLSCertPath, + t.cfg.TLSKeyPath, t.cfg.TLSExtraIPs, + t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, + t.cfg.TLSCertDuration, + ) + if err != nil { + return nil, err + } + rpcsLog.Infof("Done renewing TLS certificates") + + // Reload the certificate data. + reloadedCertData, _, err := cert.LoadCert( + t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + ) + if err != nil { + return nil, err + } + + return &reloadedCertData, nil +} + +// setUpLetsEncrypt automatically generates a Let's Encrypt certificate if the +// option is set. +func (t *TLSManager) setUpLetsEncrypt(certData *tls.Certificate, + tlsCfg *tls.Config) func() { + + // If Let's Encrypt is enabled, instantiate autocert to request/renew + // the certificates. + cleanUp := func() {} + if t.cfg.LetsEncryptDomain == "" { + return cleanUp + } + + ltndLog.Infof("Using Let's Encrypt certificate for domain %v", + t.cfg.LetsEncryptDomain) + + manager := autocert.Manager{ + Cache: autocert.DirCache(t.cfg.LetsEncryptDir), + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist( + t.cfg.LetsEncryptDomain, + ), + } + + srv := &http.Server{ + Addr: t.cfg.LetsEncryptListen, + Handler: manager.HTTPHandler(nil), + } + shutdownCompleted := make(chan struct{}) + cleanUp = func() { + err := srv.Shutdown(context.Background()) + if err != nil { + ltndLog.Errorf("Autocert listener shutdown "+ + " error: %v", err) + + return + } + <-shutdownCompleted + ltndLog.Infof("Autocert challenge listener stopped") + } + + go func() { + ltndLog.Infof("Autocert challenge listener started "+ + "at %v", t.cfg.LetsEncryptListen) + + err := srv.ListenAndServe() + if err != http.ErrServerClosed { + ltndLog.Errorf("autocert http: %v", err) + } + close(shutdownCompleted) + }() + + getCertificate := func(h *tls.ClientHelloInfo) ( + *tls.Certificate, error) { + + lecert, err := manager.GetCertificate(h) + if err != nil { + ltndLog.Errorf("GetCertificate: %v", err) + return certData, nil + } + + return lecert, err + } + + // The self-signed tls.cert remains available as fallback. + tlsCfg.GetCertificate = getCertificate + + return cleanUp +} + // fileExists reports whether the named file or directory exists. // This function is taken from https://github.com/btcsuite/btcd func fileExists(name string) bool { diff --git a/server_test.go b/server_test.go index 6c570e332..779e6414a 100644 --- a/server_test.go +++ b/server_test.go @@ -71,7 +71,8 @@ func TestTLSAutoRegeneration(t *testing.T) { TLSCertDuration: 42 * time.Hour, RPCListeners: rpcListeners, } - _, _, _, cleanUp, err := getTLSConfig(cfg) + tlsManager := NewTLSManager(cfg) + _, _, _, cleanUp, err := tlsManager.getConfig() if err != nil { t.Fatalf("couldn't retrieve TLS config") }