tlsmanager: fix autocert autogeneration

As the getConfig() function would previously overwrite the
GetCertificateFunction of the tls config, the autocert manager would
never be used.
This commit is contained in:
sputn1ck
2023-05-31 12:41:25 +02:00
parent 33b470b4a6
commit 10f9748193
2 changed files with 20 additions and 19 deletions

View File

@@ -94,7 +94,7 @@ func NewTLSManager(cfg *TLSManagerCfg) *TLSManager {
// getConfig returns a TLS configuration for the gRPC server and credentials // getConfig returns a TLS configuration for the gRPC server and credentials
// and a proxy destination for the REST reverse proxy. // and a proxy destination for the REST reverse proxy.
func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
func(net.Addr) (net.Listener, error), error) { func(net.Addr) (net.Listener, error), func(), error) {
var ( var (
keyBytes, certBytes []byte keyBytes, certBytes []byte
@@ -108,19 +108,19 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
t.cfg.TLSCertPath, t.cfg.TLSKeyPath, t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
} }
certData, _, err := cert.LoadCertFromBytes(certBytes, keyBytes) certData, _, err := cert.LoadCertFromBytes(certBytes, keyBytes)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
if t.tlsReloader == nil { if t.tlsReloader == nil {
tlsr, err := cert.NewTLSReloader(certBytes, keyBytes) tlsr, err := cert.NewTLSReloader(certBytes, keyBytes)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
t.tlsReloader = tlsr t.tlsReloader = tlsr
} }
@@ -128,6 +128,10 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
tlsCfg := cert.TLSConfFromCert(certData) tlsCfg := cert.TLSConfFromCert(certData)
tlsCfg.GetCertificate = t.tlsReloader.GetCertificateFunc() tlsCfg.GetCertificate = t.tlsReloader.GetCertificateFunc()
// If Let's Encrypt is enabled, we need to set up the autocert manager
// and override the TLS config's GetCertificate function.
cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg)
// If we're using the ephemeral certificate, we need to use the // If we're using the ephemeral certificate, we need to use the
// ephemeral cert path. // ephemeral cert path.
certPath := t.cfg.TLSCertPath certPath := t.cfg.TLSCertPath
@@ -141,7 +145,7 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
certPath, "", certPath, "",
) )
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
serverCreds := credentials.NewTLS(tlsCfg) serverCreds := credentials.NewTLS(tlsCfg)
@@ -171,40 +175,39 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption,
return lncfg.TLSListenOnAddress(addr, tlsCfg) return lncfg.TLSListenOnAddress(addr, tlsCfg)
} }
return serverOpts, restDialOpts, restListen, nil return serverOpts, restDialOpts, restListen, cleanUp, nil
} }
// generateOrRenewCert generates a new TLS certificate if we're not using one // generateOrRenewCert generates a new TLS certificate if we're not using one
// yet or renews it if it's outdated. // yet or renews it if it's outdated.
func (t *TLSManager) generateOrRenewCert() (*tls.Config, func(), error) { func (t *TLSManager) generateOrRenewCert() (*tls.Config, error) {
// Generete a TLS pair if we don't have one yet. // Generete a TLS pair if we don't have one yet.
var emptyKeyRing keychain.SecretKeyRing var emptyKeyRing keychain.SecretKeyRing
err := t.generateCertPair(emptyKeyRing) err := t.generateCertPair(emptyKeyRing)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
certData, parsedCert, err := cert.LoadCert( certData, parsedCert, err := cert.LoadCert(
t.cfg.TLSCertPath, t.cfg.TLSKeyPath, t.cfg.TLSCertPath, t.cfg.TLSKeyPath,
) )
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
// Check to see if the certificate needs to be renewed. If it does, we // Check to see if the certificate needs to be renewed. If it does, we
// return the newly generated certificate data instead. // return the newly generated certificate data instead.
reloadedCertData, err := t.maintainCert(parsedCert) reloadedCertData, err := t.maintainCert(parsedCert)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if reloadedCertData != nil { if reloadedCertData != nil {
certData = *reloadedCertData certData = *reloadedCertData
} }
tlsCfg := cert.TLSConfFromCert(certData) tlsCfg := cert.TLSConfFromCert(certData)
cleanUp := t.setUpLetsEncrypt(&certData, tlsCfg)
return tlsCfg, cleanUp, nil return tlsCfg, nil
} }
// generateCertPair creates and writes a TLS pair to disk if the pair // generateCertPair creates and writes a TLS pair to disk if the pair
@@ -474,7 +477,6 @@ func (t *TLSManager) SetCertificateBeforeUnlock() ([]grpc.ServerOption,
[]grpc.DialOption, func(net.Addr) (net.Listener, error), func(), []grpc.DialOption, func(net.Addr) (net.Listener, error), func(),
error) { error) {
var cleanUp func()
if t.cfg.TLSEncryptKey { if t.cfg.TLSEncryptKey {
_, err := t.loadEphemeralCertificate() _, err := t.loadEphemeralCertificate()
if err != nil { if err != nil {
@@ -482,15 +484,14 @@ func (t *TLSManager) SetCertificateBeforeUnlock() ([]grpc.ServerOption,
"ephemeral certificate: %v", err) "ephemeral certificate: %v", err)
} }
} else { } else {
_, cleanUpFunc, err := t.generateOrRenewCert() _, err := t.generateOrRenewCert()
if err != nil { if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to "+ return nil, nil, nil, nil, fmt.Errorf("unable to "+
"generate or renew TLS certificate: %v", err) "generate or renew TLS certificate: %v", err)
} }
cleanUp = cleanUpFunc
} }
serverOpts, restDialOpts, restListen, err := t.getConfig() serverOpts, restDialOpts, restListen, cleanUp, err := t.getConfig()
if err != nil { if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to load TLS "+ return nil, nil, nil, nil, fmt.Errorf("unable to load TLS "+
"credentials: %v", err) "credentials: %v", err)

View File

@@ -53,9 +53,9 @@ func TestGenerateOrRenewCert(t *testing.T) {
TLSCertDuration: testTLSCertDuration, TLSCertDuration: testTLSCertDuration,
} }
tlsManager := NewTLSManager(cfg) tlsManager := NewTLSManager(cfg)
_, cleanUp, err := tlsManager.generateOrRenewCert() _, err := tlsManager.generateOrRenewCert()
require.NoError(t, err) require.NoError(t, err)
_, _, _, err = tlsManager.getConfig() _, _, _, cleanUp, err := tlsManager.getConfig()
require.NoError(t, err, "couldn't retrieve TLS config") require.NoError(t, err, "couldn't retrieve TLS config")
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
@@ -86,7 +86,7 @@ func TestTLSManagerGenCert(t *testing.T) {
} }
tlsManager := NewTLSManager(cfg) tlsManager := NewTLSManager(cfg)
_, _, err := tlsManager.generateOrRenewCert() _, err := tlsManager.generateOrRenewCert()
require.NoError(t, err, "failed to generate new certificate") require.NoError(t, err, "failed to generate new certificate")
// After this is run, a new certificate should be created and written // After this is run, a new certificate should be created and written