Merge pull request #3237 from orbitalturtle/auto-regenerate-cert

Unit test for autoregenerating expired cert pairs
This commit is contained in:
Olaoluwa Osuntokun
2019-07-19 17:21:27 -07:00
committed by GitHub
3 changed files with 183 additions and 157 deletions

43
lnd.go
View File

@@ -184,7 +184,11 @@ func Main() error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
tlsCfg, restCreds, restProxyDest, err := getTLSConfig(cfg)
tlsCfg, restCreds, restProxyDest, err := getTLSConfig(
cfg.TLSCertPath,
cfg.TLSKeyPath,
cfg.RPCListeners,
)
if err != nil {
err := fmt.Errorf("Unable to load TLS credentials: %v", err)
ltndLog.Error(err)
@@ -551,18 +555,19 @@ func Main() error {
// getTLSConfig returns a TLS configuration for the gRPC server and credentials
// and a proxy destination for the REST reverse proxy.
func getTLSConfig(cfg *config) (*tls.Config, *credentials.TransportCredentials,
string, error) {
func getTLSConfig(tlsCertPath string, tlsKeyPath string,
rpcListeners []net.Addr) (*tls.Config,
*credentials.TransportCredentials, string, error) {
// Ensure we create TLS key and certificate if they don't exist
if !fileExists(cfg.TLSCertPath) && !fileExists(cfg.TLSKeyPath) {
err := genCertPair(cfg.TLSCertPath, cfg.TLSKeyPath)
if !fileExists(tlsCertPath) && !fileExists(tlsKeyPath) {
err := genCertPair(tlsCertPath, tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
}
certData, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
certData, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
@@ -576,17 +581,17 @@ func getTLSConfig(cfg *config) (*tls.Config, *credentials.TransportCredentials,
if time.Now().After(cert.NotAfter) {
ltndLog.Info("TLS certificate is expired, generating a new one")
err := os.Remove(cfg.TLSCertPath)
err := os.Remove(tlsCertPath)
if err != nil {
return nil, nil, "", err
}
err = os.Remove(cfg.TLSKeyPath)
err = os.Remove(tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
err = genCertPair(cfg.TLSCertPath, cfg.TLSKeyPath)
err = genCertPair(tlsCertPath, tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
@@ -599,12 +604,12 @@ func getTLSConfig(cfg *config) (*tls.Config, *credentials.TransportCredentials,
MinVersion: tls.VersionTLS12,
}
restCreds, err := credentials.NewClientTLSFromFile(cfg.TLSCertPath, "")
restCreds, err := credentials.NewClientTLSFromFile(tlsCertPath, "")
if err != nil {
return nil, nil, "", err
}
restProxyDest := cfg.RPCListeners[0].String()
restProxyDest := rpcListeners[0].String()
switch {
case strings.Contains(restProxyDest, "0.0.0.0"):
restProxyDest = strings.Replace(
@@ -682,11 +687,13 @@ func genCertPair(certFile, keyFile string) error {
}
}
// Add extra IPs to the slice.
for _, ip := range cfg.TLSExtraIPs {
ipAddr := net.ParseIP(ip)
if ipAddr != nil {
addIP(ipAddr)
if cfg != nil {
// Add extra IPs to the slice.
for _, ip := range cfg.TLSExtraIPs {
ipAddr := net.ParseIP(ip)
if ipAddr != nil {
addIP(ipAddr)
}
}
}
@@ -702,7 +709,9 @@ func genCertPair(certFile, keyFile string) error {
if host != "localhost" {
dnsNames = append(dnsNames, "localhost")
}
dnsNames = append(dnsNames, cfg.TLSExtraDomains...)
if cfg != nil {
dnsNames = append(dnsNames, cfg.TLSExtraDomains...)
}
// Also add fake hostnames for unix sockets, otherwise hostname
// verification will fail in the client.