From 10f9748193c2ccf14b5c60396971e65d7c104198 Mon Sep 17 00:00:00 2001 From: sputn1ck Date: Wed, 31 May 2023 12:41:25 +0200 Subject: [PATCH] tlsmanager: fix autocert autogeneration As the getConfig() function would previously overwrite the GetCertificateFunction of the tls config, the autocert manager would never be used. --- tls_manager.go | 33 +++++++++++++++++---------------- tls_manager_test.go | 6 +++--- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tls_manager.go b/tls_manager.go index 85defe837..577d2fc6d 100644 --- a/tls_manager.go +++ b/tls_manager.go @@ -94,7 +94,7 @@ func NewTLSManager(cfg *TLSManagerCfg) *TLSManager { // getConfig returns a TLS configuration for the gRPC server and credentials // and a proxy destination for the REST reverse proxy. func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, - func(net.Addr) (net.Listener, error), error) { + func(net.Addr) (net.Listener, error), func(), error) { var ( keyBytes, certBytes []byte @@ -108,19 +108,19 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, t.cfg.TLSCertPath, t.cfg.TLSKeyPath, ) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } } certData, _, err := cert.LoadCertFromBytes(certBytes, keyBytes) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } if t.tlsReloader == nil { tlsr, err := cert.NewTLSReloader(certBytes, keyBytes) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } t.tlsReloader = tlsr } @@ -128,6 +128,10 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, tlsCfg := cert.TLSConfFromCert(certData) tlsCfg.GetCertificate = t.tlsReloader.GetCertificateFunc() + // If Let's Encrypt is enabled, we need to set up the autocert manager + // and override the TLS config's GetCertificate function. + cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg) + // If we're using the ephemeral certificate, we need to use the // ephemeral cert path. certPath := t.cfg.TLSCertPath @@ -141,7 +145,7 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, certPath, "", ) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } serverCreds := credentials.NewTLS(tlsCfg) @@ -171,40 +175,39 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, return lncfg.TLSListenOnAddress(addr, tlsCfg) } - return serverOpts, restDialOpts, restListen, nil + return serverOpts, restDialOpts, restListen, cleanUp, nil } // generateOrRenewCert generates a new TLS certificate if we're not using one // yet or renews it if it's outdated. -func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) { +func (t *TLSManager) generateOrRenewCert() (*tls.Config, error) { // Generete a TLS pair if we don't have one yet. var emptyKeyRing keychain.SecretKeyRing err := t.generateCertPair(emptyKeyRing) if err != nil { - return nil, nil, err + return nil, err } certData, parsedCert, err := cert.LoadCert( t.cfg.TLSCertPath, t.cfg.TLSKeyPath, ) if err != nil { - return nil, nil, err + return nil, err } // Check to see if the certificate needs to be renewed. If it does, we // return the newly generated certificate data instead. reloadedCertData, err := t.maintainCert(parsedCert) if err != nil { - return nil, nil, err + return nil, err } if reloadedCertData != nil { certData = *reloadedCertData } tlsCfg := cert.TLSConfFromCert(certData) - cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg) - return tlsCfg, cleanUp, nil + return tlsCfg, nil } // generateCertPair creates and writes a TLS pair to disk if the pair @@ -474,7 +477,6 @@ func (t *TLSManager) SetCertificateBeforeUnlock() ([]grpc.ServerOption, []grpc.DialOption, func(net.Addr) (net.Listener, error), func(), error) { - var cleanUp func() if t.cfg.TLSEncryptKey { _, err := t.loadEphemeralCertificate() if err != nil { @@ -482,15 +484,14 @@ func (t *TLSManager) SetCertificateBeforeUnlock() ([]grpc.ServerOption, "ephemeral certificate: %v", err) } } else { - _, cleanUpFunc, err := t.generateOrRenewCert() + _, err := t.generateOrRenewCert() if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to "+ "generate or renew TLS certificate: %v", err) } - cleanUp = cleanUpFunc } - serverOpts, restDialOpts, restListen, err := t.getConfig() + serverOpts, restDialOpts, restListen, cleanUp, err := t.getConfig() if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to load TLS "+ "credentials: %v", err) diff --git a/tls_manager_test.go b/tls_manager_test.go index fcd2eadbf..935834971 100644 --- a/tls_manager_test.go +++ b/tls_manager_test.go @@ -53,9 +53,9 @@ func TestGenerateOrRenewCert(t *testing.T) { TLSCertDuration: testTLSCertDuration, } tlsManager := NewTLSManager(cfg) - _, cleanUp, err := tlsManager.generateOrRenewCert() + _, err := tlsManager.generateOrRenewCert() require.NoError(t, err) - _, _, _, err = tlsManager.getConfig() + _, _, _, cleanUp, err := tlsManager.getConfig() require.NoError(t, err, "couldn't retrieve TLS config") t.Cleanup(cleanUp) @@ -86,7 +86,7 @@ func TestTLSManagerGenCert(t *testing.T) { } tlsManager := NewTLSManager(cfg) - _, _, err := tlsManager.generateOrRenewCert() + _, err := tlsManager.generateOrRenewCert() require.NoError(t, err, "failed to generate new certificate") // After this is run, a new certificate should be created and written