cert: add TLS reloader and return bytes from GenCert

Co-authored-by: gkrizek <graham@krizek.io>
This commit is contained in:
Orbital
2022-05-24 13:26:32 -05:00
parent 84401f6f6c
commit 2f35b9aa7f
5 changed files with 239 additions and 40 deletions

View File

@@ -1,6 +1,8 @@
package cert_test
import (
"io/ioutil"
"path/filepath"
"testing"
"time"
@@ -26,20 +28,28 @@ func TestIsOutdatedCert(t *testing.T) {
keyPath := tempDir + "/tls.key"
// Generate TLS files with two extra IPs and domains.
err := cert.GenCertPair(
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], false, testTLSCertDuration,
)
if err != nil {
t.Fatal(err)
}
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
// We'll attempt to check up-to-date status for all variants of 1-3
// number of IPs and domains.
for numIPs := 1; numIPs <= len(extraIPs); numIPs++ {
for numDomains := 1; numDomains <= len(extraDomains); numDomains++ {
_, parsedCert, err := cert.LoadCert(
certPath, keyPath,
certBytes, err := ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err := ioutil.ReadFile(keyPath)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
if err != nil {
t.Fatal(err)
@@ -78,17 +88,24 @@ func TestIsOutdatedPermutation(t *testing.T) {
keyPath := tempDir + "/tls.key"
// Generate TLS files from the IPs and domains.
err := cert.GenCertPair(
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:],
extraDomains[:], false, testTLSCertDuration,
)
if err != nil {
t.Fatal(err)
}
_, parsedCert, err := cert.LoadCert(certPath, keyPath)
if err != nil {
t.Fatal(err)
}
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCertFromBytes(certBytes, keyBytes)
require.NoError(t, err)
// If we have duplicate IPs or DNS names listed, that shouldn't matter.
dupIPs := make([]string, len(extraIPs)*2)
@@ -142,7 +159,7 @@ func TestTLSDisableAutofill(t *testing.T) {
keyPath := tempDir + "/tls.key"
// Generate TLS files with two extra IPs and domains and no interface IPs.
err := cert.GenCertPair(
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], true, testTLSCertDuration,
)
@@ -150,9 +167,19 @@ func TestTLSDisableAutofill(t *testing.T) {
t, err,
"unable to generate tls certificate pair",
)
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
_, parsedCert, err := cert.LoadCert(
certPath, keyPath,
// Read certs from disk.
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
// Load the certificate.
_, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
require.NoError(
t, err,
@@ -160,7 +187,7 @@ func TestTLSDisableAutofill(t *testing.T) {
)
// Check if the TLS cert is outdated while still preventing
// interface IPs from being used. Should not be outdated
// interface IPs from being used. Should not be outdated.
shouldNotBeOutdated, err := cert.IsOutdated(
parsedCert, extraIPs[:2],
extraDomains[:2], true,
@@ -185,3 +212,51 @@ func TestTLSDisableAutofill(t *testing.T) {
"TLS Certificate was not marked as outdated when it should be",
)
}
// TestTLSConfig tests to ensure we can generate a TLS Config from
// a tls cert and tls key.
func TestTLSConfig(t *testing.T) {
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "/tls.cert")
keyPath := filepath.Join(tempDir, "/tls.key")
// Generate TLS files with an extra IP and domain.
certBytes, keyBytes, err := cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath,
[]string{extraIPs[0]}, []string{extraDomains[0]}, false,
testTLSCertDuration,
)
require.NoError(t, err)
err = cert.WriteCertPair(certPath, keyPath, certBytes, keyBytes)
require.NoError(t, err)
certBytes, err = ioutil.ReadFile(certPath)
require.NoError(t, err)
keyBytes, err = ioutil.ReadFile(keyPath)
require.NoError(t, err)
// Load the certificate.
certData, parsedCert, err := cert.LoadCertFromBytes(
certBytes, keyBytes,
)
require.NoError(t, err)
// Check to make sure the IP and domain are in the cert.
var foundIp bool
require.Contains(t, parsedCert.DNSNames, extraDomains[0])
for _, ip := range parsedCert.IPAddresses {
if ip.String() == extraIPs[0] {
foundIp = true
break
}
}
require.Equal(t, true, foundIp, "Did not find required ip inside of "+
"TLS Certificate.")
// Create TLS Config.
tlsCfg := cert.TLSConfFromCert(certData)
require.Equal(t, 1, len(tlsCfg.Certificates))
}