Lnd + server_test: Add unit test for TLS cert autoregeneration

This commit is contained in:
Turtle
2019-06-23 00:07:10 -04:00
parent 24ca95ab10
commit f958555ce3
2 changed files with 183 additions and 18 deletions

43
lnd.go
View File

@ -181,7 +181,11 @@ func Main() error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
tlsCfg, restCreds, restProxyDest, err := getTLSConfig(cfg)
tlsCfg, restCreds, restProxyDest, err := getTLSConfig(
cfg.TLSCertPath,
cfg.TLSKeyPath,
cfg.RPCListeners,
)
if err != nil {
return err
}
@ -503,18 +507,19 @@ func Main() error {
// getTLSConfig returns a TLS configuration for the gRPC server and credentials
// and a proxy destination for the REST reverse proxy.
func getTLSConfig(cfg *config) (*tls.Config, *credentials.TransportCredentials,
string, error) {
func getTLSConfig(tlsCertPath string, tlsKeyPath string,
rpcListeners []net.Addr) (*tls.Config,
*credentials.TransportCredentials, string, error) {
// Ensure we create TLS key and certificate if they don't exist
if !fileExists(cfg.TLSCertPath) && !fileExists(cfg.TLSKeyPath) {
err := genCertPair(cfg.TLSCertPath, cfg.TLSKeyPath)
if !fileExists(tlsCertPath) && !fileExists(tlsKeyPath) {
err := genCertPair(tlsCertPath, tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
}
certData, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
certData, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
@ -528,17 +533,17 @@ func getTLSConfig(cfg *config) (*tls.Config, *credentials.TransportCredentials,
if time.Now().After(cert.NotAfter) {
ltndLog.Info("TLS certificate is expired, generating a new one")
err := os.Remove(cfg.TLSCertPath)
err := os.Remove(tlsCertPath)
if err != nil {
return nil, nil, "", err
}
err = os.Remove(cfg.TLSKeyPath)
err = os.Remove(tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
err = genCertPair(cfg.TLSCertPath, cfg.TLSKeyPath)
err = genCertPair(tlsCertPath, tlsKeyPath)
if err != nil {
return nil, nil, "", err
}
@ -551,12 +556,12 @@ func getTLSConfig(cfg *config) (*tls.Config, *credentials.TransportCredentials,
MinVersion: tls.VersionTLS12,
}
restCreds, err := credentials.NewClientTLSFromFile(cfg.TLSCertPath, "")
restCreds, err := credentials.NewClientTLSFromFile(tlsCertPath, "")
if err != nil {
return nil, nil, "", err
}
restProxyDest := cfg.RPCListeners[0].String()
restProxyDest := rpcListeners[0].String()
switch {
case strings.Contains(restProxyDest, "0.0.0.0"):
restProxyDest = strings.Replace(
@ -634,11 +639,13 @@ func genCertPair(certFile, keyFile string) error {
}
}
// Add extra IPs to the slice.
for _, ip := range cfg.TLSExtraIPs {
ipAddr := net.ParseIP(ip)
if ipAddr != nil {
addIP(ipAddr)
if cfg != nil {
// Add extra IPs to the slice.
for _, ip := range cfg.TLSExtraIPs {
ipAddr := net.ParseIP(ip)
if ipAddr != nil {
addIP(ipAddr)
}
}
}
@ -654,7 +661,9 @@ func genCertPair(certFile, keyFile string) error {
if host != "localhost" {
dnsNames = append(dnsNames, "localhost")
}
dnsNames = append(dnsNames, cfg.TLSExtraDomains...)
if cfg != nil {
dnsNames = append(dnsNames, cfg.TLSExtraDomains...)
}
// Also add fake hostnames for unix sockets, otherwise hostname
// verification will fail in the client.