lnencrypt: Moves the crypto functions in the chanbackup package into its own package called lnencrypt

The functions inside of the crypto.go file in chanbackup (like EncryptPayloadToWriter and DecryptPayloadFromReader) can be used by a lot of things outside of just the chanbackup package. We can't just reference them directly from the chanbackup package because it's likely that it would generate circular dependencies. Therefore we need to move these functions into their own package to be referenced by chanbackup and whatever new functionality that needs them
This commit is contained in:
Graham Krizek
2020-07-11 22:29:00 -05:00
committed by Orbital
parent f3bd2227fa
commit e0fc5bb234
10 changed files with 152 additions and 124 deletions

View File

@ -9,6 +9,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -184,7 +185,7 @@ func assertMultiEqual(t *testing.T, a, b *Multi) {
func TestExtractMulti(t *testing.T) { func TestExtractMulti(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// First, as prep, we'll create a single chan backup, then pack that // First, as prep, we'll create a single chan backup, then pack that
// fully into a multi backup. // fully into a multi backup.

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -89,7 +90,12 @@ func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
// With the plaintext multi backup assembled, we'll now encrypt it // With the plaintext multi backup assembled, we'll now encrypt it
// directly to the passed writer. // directly to the passed writer.
return encryptPayloadToWriter(multiBackupBuffer, w, keyRing) e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate encrypt key %v", err)
}
return e.EncryptPayloadToWriter(multiBackupBuffer.Bytes(), w)
} }
// UnpackFromReader attempts to unpack (decrypt+deserialize) a packed // UnpackFromReader attempts to unpack (decrypt+deserialize) a packed
@ -99,7 +105,11 @@ func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
// We'll attempt to read the entire packed backup, and also decrypt it // We'll attempt to read the entire packed backup, and also decrypt it
// using the passed key ring which is expected to be able to derive the // using the passed key ring which is expected to be able to derive the
// encryption keys. // encryption keys.
plaintextBackup, err := decryptPayloadFromReader(r, keyRing) e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate encrypt key %v", err)
}
plaintextBackup, err := e.DecryptPayloadFromReader(r)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@ import (
"net" "net"
"testing" "testing"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -27,7 +28,7 @@ func TestMultiPackUnpack(t *testing.T) {
multi.StaticBackups = append(multi.StaticBackups, single) multi.StaticBackups = append(multi.StaticBackups, single)
} }
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
versionTestCases := []struct { versionTestCases := []struct {
// version is the pack/unpack version that we should use to // version is the pack/unpack version that we should use to
@ -93,14 +94,17 @@ func TestMultiPackUnpack(t *testing.T) {
) )
} }
encrypter, err := lnencrypt.KeyRingEncrypter(keyRing)
require.NoError(t, err)
// Next, we'll make a fake packed multi, it'll have an // Next, we'll make a fake packed multi, it'll have an
// unknown version relative to what's implemented atm. // unknown version relative to what's implemented atm.
var fakePackedMulti bytes.Buffer var fakePackedMulti bytes.Buffer
fakeRawMulti := bytes.NewBuffer( fakeRawMulti := bytes.NewBuffer(
bytes.Repeat([]byte{99}, 20), bytes.Repeat([]byte{99}, 20),
) )
err := encryptPayloadToWriter( err = encrypter.EncryptPayloadToWriter(
*fakeRawMulti, &fakePackedMulti, keyRing, fakeRawMulti.Bytes(), &fakePackedMulti,
) )
if err != nil { if err != nil {
t.Fatalf("unable to pack fake multi; %v", err) t.Fatalf("unable to pack fake multi; %v", err)
@ -124,7 +128,7 @@ func TestMultiPackUnpack(t *testing.T) {
func TestPackedMultiUnpack(t *testing.T) { func TestPackedMultiUnpack(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// First, we'll make a new unpacked multi with a random channel. // First, we'll make a new unpacked multi with a random channel.
testChannel, err := genRandomOpenChannelShell() testChannel, err := genRandomOpenChannelShell()

View File

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -80,7 +81,7 @@ func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
func TestNewSubSwapperSubscribeFail(t *testing.T) { func TestNewSubSwapperSubscribeFail(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
var swapper mockSwapper var swapper mockSwapper
chanNotifier := mockChannelNotifier{ chanNotifier := mockChannelNotifier{
@ -152,7 +153,7 @@ func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper,
func TestSubSwapperIdempotentStartStop(t *testing.T) { func TestSubSwapperIdempotentStartStop(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
var chanNotifier mockChannelNotifier var chanNotifier mockChannelNotifier
@ -181,7 +182,7 @@ func TestSubSwapperIdempotentStartStop(t *testing.T) {
func TestSubSwapperUpdater(t *testing.T) { func TestSubSwapperUpdater(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
chanNotifier := newMockChannelNotifier() chanNotifier := newMockChannelNotifier()
swapper := newMockSwapper(keyRing) swapper := newMockSwapper(keyRing)

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -49,7 +50,7 @@ func (m *mockPeerConnector) ConnectPeer(node *btcec.PublicKey,
func TestUnpackAndRecoverSingles(t *testing.T) { func TestUnpackAndRecoverSingles(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// First, we'll create a number of single chan backups that we'll // First, we'll create a number of single chan backups that we'll
// shortly back to so we can begin our recovery attempt. // shortly back to so we can begin our recovery attempt.
@ -123,7 +124,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) {
} }
// If we modify the keyRing, then unpacking should fail. // If we modify the keyRing, then unpacking should fail.
keyRing.fail = true keyRing.Fail = true
err = UnpackAndRecoverSingles( err = UnpackAndRecoverSingles(
packedBackups, keyRing, &chanRestorer, &peerConnector, packedBackups, keyRing, &chanRestorer, &peerConnector,
) )
@ -139,7 +140,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) {
func TestUnpackAndRecoverMulti(t *testing.T) { func TestUnpackAndRecoverMulti(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// First, we'll create a number of single chan backups that we'll // First, we'll create a number of single chan backups that we'll
// shortly back to so we can begin our recovery attempt. // shortly back to so we can begin our recovery attempt.
@ -217,7 +218,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) {
} }
// If we modify the keyRing, then unpacking should fail. // If we modify the keyRing, then unpacking should fail.
keyRing.fail = true keyRing.Fail = true
err = UnpackAndRecoverMulti( err = UnpackAndRecoverMulti(
packedMulti, keyRing, &chanRestorer, &peerConnector, packedMulti, keyRing, &chanRestorer, &peerConnector,
) )

View File

@ -12,6 +12,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -351,7 +352,11 @@ func (s *Single) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
// Finally, we'll encrypt the raw serialized SCB (using the nonce as // Finally, we'll encrypt the raw serialized SCB (using the nonce as
// associated data), and write out the ciphertext prepend with the // associated data), and write out the ciphertext prepend with the
// nonce that we used to the passed io.Reader. // nonce that we used to the passed io.Reader.
return encryptPayloadToWriter(rawBytes, w, keyRing) e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate encrypt key %v", err)
}
return e.EncryptPayloadToWriter(rawBytes.Bytes(), w)
} }
// readLocalKeyDesc reads a KeyDescriptor encoded within an unpacked Single. // readLocalKeyDesc reads a KeyDescriptor encoded within an unpacked Single.
@ -528,7 +533,11 @@ func (s *Single) Deserialize(r io.Reader) error {
// payload for whatever reason (wrong key, wrong nonce, etc), then this method // payload for whatever reason (wrong key, wrong nonce, etc), then this method
// will return an error. // will return an error.
func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
plaintext, err := decryptPayloadFromReader(r, keyRing) e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate key decrypter %v", err)
}
plaintext, err := e.DecryptPayloadFromReader(r)
if err != nil { if err != nil {
return err return err
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -207,7 +208,7 @@ func TestSinglePackUnpack(t *testing.T) {
singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2}) singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2})
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
versionTestCases := []struct { versionTestCases := []struct {
// version is the pack/unpack version that we should use to // version is the pack/unpack version that we should use to
@ -312,7 +313,7 @@ func TestSinglePackUnpack(t *testing.T) {
func TestPackedSinglesUnpack(t *testing.T) { func TestPackedSinglesUnpack(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// To start, we'll create 10 new singles, and them assemble their // To start, we'll create 10 new singles, and them assemble their
// packed forms into a slice. // packed forms into a slice.
@ -361,7 +362,7 @@ func TestPackedSinglesUnpack(t *testing.T) {
func TestSinglePackStaticChanBackups(t *testing.T) { func TestSinglePackStaticChanBackups(t *testing.T) {
t.Parallel() t.Parallel()
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// First, we'll create a set of random single, and along the way, // First, we'll create a set of random single, and along the way,
// create a map that will let us look up each single by its chan point. // create a map that will let us look up each single by its chan point.
@ -407,8 +408,9 @@ func TestSinglePackStaticChanBackups(t *testing.T) {
// If we attempt to pack again, but force the key ring to fail, then // If we attempt to pack again, but force the key ring to fail, then
// the entire method should fail. // the entire method should fail.
keyRing.Fail = true
_, err = PackStaticChanBackups( _, err = PackStaticChanBackups(
unpackedSingles, &mockKeyRing{true}, unpackedSingles, &lnencrypt.MockKeyRing{Fail: true},
) )
if err == nil { if err == nil {
t.Fatalf("pack attempt should fail") t.Fatalf("pack attempt should fail")
@ -432,7 +434,7 @@ func TestSingleUnconfirmedChannel(t *testing.T) {
channel.FundingBroadcastHeight = fundingBroadcastHeight channel.FundingBroadcastHeight = fundingBroadcastHeight
singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2}) singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2})
keyRing := &mockKeyRing{} keyRing := &lnencrypt.MockKeyRing{}
// Pack it and then unpack it again to make sure everything is written // Pack it and then unpack it again to make sure everything is written
// correctly, then check that the block height of the unpacked // correctly, then check that the block height of the unpacked

View File

@ -1,7 +1,6 @@
package chanbackup package lnencrypt
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
@ -12,8 +11,6 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
// TODO(roasbeef): interface in front of?
// baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base // baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base
// encryption key used for encrypting all payloads. We use this to then // encryption key used for encrypting all payloads. We use this to then
// derive the actual key that we'll use for encryption. We do this // derive the actual key that we'll use for encryption. We do this
@ -27,12 +24,32 @@ var baseEncryptionKeyLoc = keychain.KeyLocator{
Index: 0, Index: 0,
} }
// genEncryptionKey derives the key that we'll use to encrypt all of our static // EncrypterDecrypter is an interface representing an object that encrypts or
// channel backups. The key itself, is the sha2 of a base key that we get from // decrypts data.
// the keyring. We derive the key this way as we don't force the HSM (or any type EncrypterDecrypter interface {
// future abstractions) to be able to derive and know of the cipher that we'll // EncryptPayloadToWriter attempts to write the set of provided bytes
// use within our protocol. // into the passed io.Writer in an encrypted form.
func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) { EncryptPayloadToWriter([]byte, io.Writer) error
// DecryptPayloadFromReader attempts to decrypt the encrypted bytes
// within the passed io.Reader instance using the key derived from
// the passed keyRing.
DecryptPayloadFromReader(io.Reader) ([]byte, error)
}
// Encrypter is a struct responsible for encrypting and decrypting data.
type Encrypter struct {
encryptionKey []byte
}
// KeyRingEncrypter derives an encryption key to encrypt all our files that are
// written to disk and returns an Encrypter object holding the key.
//
// The key itself, is the sha2 of a base key that we get from the keyring. We
// derive the key this way as we don't force the HSM (or any future
// abstractions) to be able to derive and know of the cipher that we'll use
// within our protocol.
func KeyRingEncrypter(keyRing keychain.KeyRing) (*Encrypter, error) {
// key = SHA256(baseKey) // key = SHA256(baseKey)
baseKey, err := keyRing.DeriveKey( baseKey, err := keyRing.DeriveKey(
baseEncryptionKeyLoc, baseEncryptionKeyLoc,
@ -47,33 +64,23 @@ func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) {
// TODO(roasbeef): throw back in ECDH? // TODO(roasbeef): throw back in ECDH?
return encryptionKey[:], nil return &Encrypter{
encryptionKey: encryptionKey[:],
}, nil
} }
// encryptPayloadToWriter attempts to write the set of bytes contained within // EncryptPayloadToWriter attempts to write the set of provided bytes into the
// the passed byes.Buffer into the passed io.Writer in an encrypted form. We // passed io.Writer in an encrypted form. We use a 24-byte chachapoly AEAD
// use a 24-byte chachapoly AEAD instance with a randomized nonce that's // instance with a randomized nonce that's pre-pended to the final payload and
// pre-pended to the final payload and used as associated data in the AEAD. We // used as associated data in the AEAD.
// use the passed keyRing to generate the encryption key, see genEncryptionKey func (e Encrypter) EncryptPayloadToWriter(payload []byte,
// for further details. w io.Writer) error {
func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
keyRing keychain.KeyRing) error {
// First, we'll derive the key that we'll use to encrypt the payload
// for safe storage without giving away the details of any of our
// channels. The final operation is:
//
// key = SHA256(baseKey)
encryptionKey, err := genEncryptionKey(keyRing)
if err != nil {
return err
}
// Before encryption, we'll initialize our cipher with the target // Before encryption, we'll initialize our cipher with the target
// encryption key, and also read out our random 24-byte nonce we use // encryption key, and also read out our random 24-byte nonce we use
// for encryption. Note that we use NewX, not New, as the latter // for encryption. Note that we use NewX, not New, as the latter
// version requires a 12-byte nonce, not a 24-byte nonce. // version requires a 12-byte nonce, not a 24-byte nonce.
cipher, err := chacha20poly1305.NewX(encryptionKey) cipher, err := chacha20poly1305.NewX(e.encryptionKey)
if err != nil { if err != nil {
return err return err
} }
@ -84,7 +91,7 @@ func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
// Finally, we encrypted the final payload, and write out our // Finally, we encrypted the final payload, and write out our
// ciphertext with nonce pre-pended. // ciphertext with nonce pre-pended.
ciphertext := cipher.Seal(nil, nonce[:], payload.Bytes(), nonce[:]) ciphertext := cipher.Seal(nil, nonce[:], payload, nonce[:])
if _, err := w.Write(nonce[:]); err != nil { if _, err := w.Write(nonce[:]); err != nil {
return err return err
@ -96,38 +103,30 @@ func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
return nil return nil
} }
// decryptPayloadFromReader attempts to decrypt the encrypted bytes within the // DecryptPayloadFromReader attempts to decrypt the encrypted bytes within the
// passed io.Reader instance using the key derived from the passed keyRing. For // passed io.Reader instance using the key derived from the passed keyRing. For
// further details regarding the key derivation protocol, see the // further details regarding the key derivation protocol, see the
// genEncryptionKey method. // KeyRingEncrypter function.
func decryptPayloadFromReader(payload io.Reader, func (e Encrypter) DecryptPayloadFromReader(payload io.Reader) ([]byte,
keyRing keychain.KeyRing) ([]byte, error) { error) {
// First, we'll re-generate the encryption key that we use for all the
// SCBs.
encryptionKey, err := genEncryptionKey(keyRing)
if err != nil {
return nil, err
}
// Next, we'll read out the entire blob as we need to isolate the nonce // Next, we'll read out the entire blob as we need to isolate the nonce
// from the rest of the ciphertext. // from the rest of the ciphertext.
packedBackup, err := ioutil.ReadAll(payload) packedPayload, err := ioutil.ReadAll(payload)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(packedBackup) < chacha20poly1305.NonceSizeX { if len(packedPayload) < chacha20poly1305.NonceSizeX {
return nil, fmt.Errorf("payload size too small, must be at "+ return nil, fmt.Errorf("payload size too small, must be at "+
"least %v bytes", chacha20poly1305.NonceSizeX) "least %v bytes", chacha20poly1305.NonceSizeX)
} }
nonce := packedBackup[:chacha20poly1305.NonceSizeX] nonce := packedPayload[:chacha20poly1305.NonceSizeX]
ciphertext := packedBackup[chacha20poly1305.NonceSizeX:] ciphertext := packedPayload[chacha20poly1305.NonceSizeX:]
// Now that we have the cipher text and the nonce separated, we can go // Now that we have the cipher text and the nonce separated, we can go
// ahead and decrypt the final blob so we can properly serialized the // ahead and decrypt the final blob so we can properly serialize.
// SCB. cipher, err := chacha20poly1305.NewX(e.encryptionKey)
cipher, err := chacha20poly1305.NewX(encryptionKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,41 +1,12 @@
package chanbackup package lnencrypt
import ( import (
"bytes" "bytes"
"fmt"
"testing" "testing"
"github.com/btcsuite/btcd/btcec/v2" "github.com/stretchr/testify/require"
"github.com/lightningnetwork/lnd/keychain"
) )
var (
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
)
type mockKeyRing struct {
fail bool
}
func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
return keychain.KeyDescriptor{}, nil
}
func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
if m.fail {
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
}
_, pub := btcec.PrivKeyFromBytes(testWalletPrivKey)
return keychain.KeyDescriptor{
PubKey: pub,
}, nil
}
// TestEncryptDecryptPayload tests that given a static key, we're able to // TestEncryptDecryptPayload tests that given a static key, we're able to
// properly decrypt and encrypted payload. We also test that we'll reject a // properly decrypt and encrypted payload. We also test that we'll reject a
// ciphertext that has been modified. // ciphertext that has been modified.
@ -81,15 +52,16 @@ func TestEncryptDecryptPayload(t *testing.T) {
}, },
} }
keyRing := &mockKeyRing{} keyRing := &MockKeyRing{}
for i, payloadCase := range payloadCases { for i, payloadCase := range payloadCases {
var cipherBuffer bytes.Buffer var cipherBuffer bytes.Buffer
encrypter, err := KeyRingEncrypter(keyRing)
require.NoError(t, err)
// First, we'll encrypt the passed payload with our scheme. // First, we'll encrypt the passed payload with our scheme.
payloadReader := bytes.NewBuffer(payloadCase.plaintext) err = encrypter.EncryptPayloadToWriter(
err := encryptPayloadToWriter( payloadCase.plaintext, &cipherBuffer,
*payloadReader, &cipherBuffer, keyRing,
) )
if err != nil { if err != nil {
t.Fatalf("unable encrypt paylaod: %v", err) t.Fatalf("unable encrypt paylaod: %v", err)
@ -107,7 +79,9 @@ func TestEncryptDecryptPayload(t *testing.T) {
cipherBuffer.Write(cipherText) cipherBuffer.Write(cipherText)
} }
plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing) plaintext, err := encrypter.DecryptPayloadFromReader(
&cipherBuffer,
)
switch { switch {
// If this was meant to be a valid decryption, but we failed, // If this was meant to be a valid decryption, but we failed,
@ -131,26 +105,13 @@ func TestEncryptDecryptPayload(t *testing.T) {
} }
} }
// TestInvalidKeyEncryption tests that encryption fails if we're unable to // TestInvalidKeyGeneration tests that key generation fails when deriving the
// obtain a valid key. // key fails.
func TestInvalidKeyEncryption(t *testing.T) { func TestInvalidKeyGeneration(t *testing.T) {
t.Parallel() t.Parallel()
var b bytes.Buffer _, err := KeyRingEncrypter(&MockKeyRing{true})
err := encryptPayloadToWriter(b, &b, &mockKeyRing{true})
if err == nil { if err == nil {
t.Fatalf("expected error due to fail key gen") t.Fatal("expected error due to fail key gen")
}
}
// TestInvalidKeyDecrytion tests that decryption fails if we're unable to
// obtain a valid key.
func TestInvalidKeyDecrytion(t *testing.T) {
t.Parallel()
var b bytes.Buffer
_, err := decryptPayloadFromReader(&b, &mockKeyRing{true})
if err == nil {
t.Fatalf("expected error due to fail key gen")
} }
} }

40
lnencrypt/test_utils.go Normal file
View File

@ -0,0 +1,40 @@
package lnencrypt
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/keychain"
)
var (
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
)
type MockKeyRing struct {
Fail bool
}
func (m *MockKeyRing) DeriveNextKey(
keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
return keychain.KeyDescriptor{}, nil
}
func (m *MockKeyRing) DeriveKey(
keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
if m.Fail {
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
}
_, pub := btcec.PrivKeyFromBytes(testWalletPrivKey)
return keychain.KeyDescriptor{
PubKey: pub,
}, nil
}