diff --git a/lnencrypt/crypto.go b/lnencrypt/crypto.go index 5035e0c34..0ed773f61 100644 --- a/lnencrypt/crypto.go +++ b/lnencrypt/crypto.go @@ -5,8 +5,8 @@ import ( "crypto/sha256" "fmt" "io" - "io/ioutil" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/keychain" "golang.org/x/crypto/chacha20poly1305" ) @@ -69,6 +69,25 @@ func KeyRingEncrypter(keyRing keychain.KeyRing) (*Encrypter, error) { }, nil } +// ECDHEncrypter derives an encryption key by performing an ECDH operation on +// the passed keys. The resulting key is used to encrypt or decrypt files with +// sensitive content. +func ECDHEncrypter(localKey *btcec.PrivateKey, + remoteKey *btcec.PublicKey) (*Encrypter, error) { + + ecdh := keychain.PrivKeyECDH{ + PrivKey: localKey, + } + encryptionKey, err := ecdh.ECDH(remoteKey) + if err != nil { + return nil, fmt.Errorf("error deriving encryption key: %w", err) + } + + return &Encrypter{ + encryptionKey: encryptionKey[:], + }, nil +} + // EncryptPayloadToWriter attempts to write the set of provided bytes into the // passed io.Writer in an encrypted form. We use a 24-byte chachapoly AEAD // instance with a randomized nonce that's pre-pended to the final payload and @@ -112,7 +131,7 @@ func (e Encrypter) DecryptPayloadFromReader(payload io.Reader) ([]byte, // Next, we'll read out the entire blob as we need to isolate the nonce // from the rest of the ciphertext. - packedPayload, err := ioutil.ReadAll(payload) + packedPayload, err := io.ReadAll(payload) if err != nil { return nil, err } diff --git a/lnencrypt/crypto_test.go b/lnencrypt/crypto_test.go index 4a41af328..42ebe1cc2 100644 --- a/lnencrypt/crypto_test.go +++ b/lnencrypt/crypto_test.go @@ -4,6 +4,7 @@ import ( "bytes" "testing" + "github.com/btcsuite/btcd/btcec/v2" "github.com/stretchr/testify/require" ) @@ -35,7 +36,8 @@ func TestEncryptDecryptPayload(t *testing.T) { { plaintext: []byte("payload test plain text"), mutator: func(p *[]byte) { - // Flip a byte in the payload to render it invalid. + // Flip a byte in the payload to render it + // invalid. (*p)[0] ^= 1 }, valid: false, @@ -53,54 +55,55 @@ func TestEncryptDecryptPayload(t *testing.T) { } keyRing := &MockKeyRing{} + keyRingEnc, err := KeyRingEncrypter(keyRing) + require.NoError(t, err) - for i, payloadCase := range payloadCases { - var cipherBuffer bytes.Buffer - encrypter, err := KeyRingEncrypter(keyRing) - require.NoError(t, err) + _, pubKey := btcec.PrivKeyFromBytes([]byte{0x01, 0x02, 0x03, 0x04}) - // First, we'll encrypt the passed payload with our scheme. - err = encrypter.EncryptPayloadToWriter( - payloadCase.plaintext, &cipherBuffer, - ) - if err != nil { - t.Fatalf("unable encrypt paylaod: %v", err) - } + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + privKeyEnc, err := ECDHEncrypter(privKey, pubKey) + require.NoError(t, err) - // If we have a mutator, then we'll wrong the mutator over the - // cipher text, then reset the main buffer and re-write the new - // cipher text. - if payloadCase.mutator != nil { - cipherText := cipherBuffer.Bytes() + for _, payloadCase := range payloadCases { + payloadCase := payloadCase + for _, enc := range []*Encrypter{keyRingEnc, privKeyEnc} { + enc := enc - payloadCase.mutator(&cipherText) + // First, we'll encrypt the passed payload with our + // scheme. + var cipherBuffer bytes.Buffer + err = enc.EncryptPayloadToWriter( + payloadCase.plaintext, &cipherBuffer, + ) + require.NoError(t, err) - cipherBuffer.Reset() - cipherBuffer.Write(cipherText) - } + // If we have a mutator, then we'll wrong the mutator + // over the cipher text, then reset the main buffer and + // re-write the new cipher text. + if payloadCase.mutator != nil { + cipherText := cipherBuffer.Bytes() - plaintext, err := encrypter.DecryptPayloadFromReader( - &cipherBuffer, - ) + payloadCase.mutator(&cipherText) - switch { - // If this was meant to be a valid decryption, but we failed, - // then we'll return an error. - case err != nil && payloadCase.valid: - t.Fatalf("unable to decrypt valid payload case %v", i) + cipherBuffer.Reset() + cipherBuffer.Write(cipherText) + } - // If this was meant to be an invalid decryption, and we didn't - // fail, then we'll return an error. - case err == nil && !payloadCase.valid: - t.Fatalf("payload was invalid yet was able to decrypt") - } + plaintext, err := enc.DecryptPayloadFromReader( + &cipherBuffer, + ) - // Only if this case was mean to be valid will we ensure the - // resulting decrypted plaintext matches the original input. - if payloadCase.valid && - !bytes.Equal(plaintext, payloadCase.plaintext) { - t.Fatalf("#%v: expected %v, got %v: ", i, - payloadCase.plaintext, plaintext) + if !payloadCase.valid { + require.Error(t, err) + + continue + } + + require.NoError(t, err) + require.Equal( + t, plaintext, payloadCase.plaintext, + ) } } } @@ -111,7 +114,5 @@ func TestInvalidKeyGeneration(t *testing.T) { t.Parallel() _, err := KeyRingEncrypter(&MockKeyRing{true}) - if err == nil { - t.Fatal("expected error due to fail key gen") - } + require.Error(t, err) }