lnencrypt: Moves the crypto functions in the chanbackup package into its own package called lnencrypt

The functions inside of the crypto.go file in chanbackup (like EncryptPayloadToWriter and DecryptPayloadFromReader) can be used by a lot of things outside of just the chanbackup package. We can't just reference them directly from the chanbackup package because it's likely that it would generate circular dependencies. Therefore we need to move these functions into their own package to be referenced by chanbackup and whatever new functionality that needs them
This commit is contained in:
Graham Krizek
2020-07-11 22:29:00 -05:00
committed by Orbital
parent f3bd2227fa
commit e0fc5bb234
10 changed files with 152 additions and 124 deletions

View File

@ -9,6 +9,7 @@ import (
"path/filepath"
"testing"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require"
)
@ -184,7 +185,7 @@ func assertMultiEqual(t *testing.T, a, b *Multi) {
func TestExtractMulti(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// First, as prep, we'll create a single chan backup, then pack that
// fully into a multi backup.

View File

@ -6,6 +6,7 @@ import (
"io"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire"
)
@ -89,7 +90,12 @@ func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
// With the plaintext multi backup assembled, we'll now encrypt it
// directly to the passed writer.
return encryptPayloadToWriter(multiBackupBuffer, w, keyRing)
e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate encrypt key %v", err)
}
return e.EncryptPayloadToWriter(multiBackupBuffer.Bytes(), w)
}
// UnpackFromReader attempts to unpack (decrypt+deserialize) a packed
@ -99,7 +105,11 @@ func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
// We'll attempt to read the entire packed backup, and also decrypt it
// using the passed key ring which is expected to be able to derive the
// encryption keys.
plaintextBackup, err := decryptPayloadFromReader(r, keyRing)
e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate encrypt key %v", err)
}
plaintextBackup, err := e.DecryptPayloadFromReader(r)
if err != nil {
return err
}

View File

@ -5,6 +5,7 @@ import (
"net"
"testing"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require"
)
@ -27,7 +28,7 @@ func TestMultiPackUnpack(t *testing.T) {
multi.StaticBackups = append(multi.StaticBackups, single)
}
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
versionTestCases := []struct {
// version is the pack/unpack version that we should use to
@ -93,14 +94,17 @@ func TestMultiPackUnpack(t *testing.T) {
)
}
encrypter, err := lnencrypt.KeyRingEncrypter(keyRing)
require.NoError(t, err)
// Next, we'll make a fake packed multi, it'll have an
// unknown version relative to what's implemented atm.
var fakePackedMulti bytes.Buffer
fakeRawMulti := bytes.NewBuffer(
bytes.Repeat([]byte{99}, 20),
)
err := encryptPayloadToWriter(
*fakeRawMulti, &fakePackedMulti, keyRing,
err = encrypter.EncryptPayloadToWriter(
fakeRawMulti.Bytes(), &fakePackedMulti,
)
if err != nil {
t.Fatalf("unable to pack fake multi; %v", err)
@ -124,7 +128,7 @@ func TestMultiPackUnpack(t *testing.T) {
func TestPackedMultiUnpack(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// First, we'll make a new unpacked multi with a random channel.
testChannel, err := genRandomOpenChannelShell()

View File

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require"
)
@ -80,7 +81,7 @@ func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
func TestNewSubSwapperSubscribeFail(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
var swapper mockSwapper
chanNotifier := mockChannelNotifier{
@ -152,7 +153,7 @@ func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper,
func TestSubSwapperIdempotentStartStop(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
var chanNotifier mockChannelNotifier
@ -181,7 +182,7 @@ func TestSubSwapperIdempotentStartStop(t *testing.T) {
func TestSubSwapperUpdater(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
chanNotifier := newMockChannelNotifier()
swapper := newMockSwapper(keyRing)

View File

@ -7,6 +7,7 @@ import (
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/stretchr/testify/require"
)
@ -49,7 +50,7 @@ func (m *mockPeerConnector) ConnectPeer(node *btcec.PublicKey,
func TestUnpackAndRecoverSingles(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// First, we'll create a number of single chan backups that we'll
// shortly back to so we can begin our recovery attempt.
@ -123,7 +124,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) {
}
// If we modify the keyRing, then unpacking should fail.
keyRing.fail = true
keyRing.Fail = true
err = UnpackAndRecoverSingles(
packedBackups, keyRing, &chanRestorer, &peerConnector,
)
@ -139,7 +140,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) {
func TestUnpackAndRecoverMulti(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// First, we'll create a number of single chan backups that we'll
// shortly back to so we can begin our recovery attempt.
@ -217,7 +218,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) {
}
// If we modify the keyRing, then unpacking should fail.
keyRing.fail = true
keyRing.Fail = true
err = UnpackAndRecoverMulti(
packedMulti, keyRing, &chanRestorer, &peerConnector,
)

View File

@ -12,6 +12,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire"
)
@ -351,7 +352,11 @@ func (s *Single) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
// Finally, we'll encrypt the raw serialized SCB (using the nonce as
// associated data), and write out the ciphertext prepend with the
// nonce that we used to the passed io.Reader.
return encryptPayloadToWriter(rawBytes, w, keyRing)
e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate encrypt key %v", err)
}
return e.EncryptPayloadToWriter(rawBytes.Bytes(), w)
}
// readLocalKeyDesc reads a KeyDescriptor encoded within an unpacked Single.
@ -528,7 +533,11 @@ func (s *Single) Deserialize(r io.Reader) error {
// payload for whatever reason (wrong key, wrong nonce, etc), then this method
// will return an error.
func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
plaintext, err := decryptPayloadFromReader(r, keyRing)
e, err := lnencrypt.KeyRingEncrypter(keyRing)
if err != nil {
return fmt.Errorf("unable to generate key decrypter %v", err)
}
plaintext, err := e.DecryptPayloadFromReader(r)
if err != nil {
return err
}

View File

@ -14,6 +14,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/stretchr/testify/require"
@ -207,7 +208,7 @@ func TestSinglePackUnpack(t *testing.T) {
singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2})
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
versionTestCases := []struct {
// version is the pack/unpack version that we should use to
@ -312,7 +313,7 @@ func TestSinglePackUnpack(t *testing.T) {
func TestPackedSinglesUnpack(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// To start, we'll create 10 new singles, and them assemble their
// packed forms into a slice.
@ -361,7 +362,7 @@ func TestPackedSinglesUnpack(t *testing.T) {
func TestSinglePackStaticChanBackups(t *testing.T) {
t.Parallel()
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// First, we'll create a set of random single, and along the way,
// create a map that will let us look up each single by its chan point.
@ -407,8 +408,9 @@ func TestSinglePackStaticChanBackups(t *testing.T) {
// If we attempt to pack again, but force the key ring to fail, then
// the entire method should fail.
keyRing.Fail = true
_, err = PackStaticChanBackups(
unpackedSingles, &mockKeyRing{true},
unpackedSingles, &lnencrypt.MockKeyRing{Fail: true},
)
if err == nil {
t.Fatalf("pack attempt should fail")
@ -432,7 +434,7 @@ func TestSingleUnconfirmedChannel(t *testing.T) {
channel.FundingBroadcastHeight = fundingBroadcastHeight
singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2})
keyRing := &mockKeyRing{}
keyRing := &lnencrypt.MockKeyRing{}
// Pack it and then unpack it again to make sure everything is written
// correctly, then check that the block height of the unpacked

View File

@ -1,7 +1,6 @@
package chanbackup
package lnencrypt
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"fmt"
@ -12,8 +11,6 @@ import (
"golang.org/x/crypto/chacha20poly1305"
)
// TODO(roasbeef): interface in front of?
// baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base
// encryption key used for encrypting all payloads. We use this to then
// derive the actual key that we'll use for encryption. We do this
@ -27,12 +24,32 @@ var baseEncryptionKeyLoc = keychain.KeyLocator{
Index: 0,
}
// genEncryptionKey derives the key that we'll use to encrypt all of our static
// channel backups. The key itself, is the sha2 of a base key that we get from
// the keyring. We derive the key this way as we don't force the HSM (or any
// future abstractions) to be able to derive and know of the cipher that we'll
// use within our protocol.
func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) {
// EncrypterDecrypter is an interface representing an object that encrypts or
// decrypts data.
type EncrypterDecrypter interface {
// EncryptPayloadToWriter attempts to write the set of provided bytes
// into the passed io.Writer in an encrypted form.
EncryptPayloadToWriter([]byte, io.Writer) error
// DecryptPayloadFromReader attempts to decrypt the encrypted bytes
// within the passed io.Reader instance using the key derived from
// the passed keyRing.
DecryptPayloadFromReader(io.Reader) ([]byte, error)
}
// Encrypter is a struct responsible for encrypting and decrypting data.
type Encrypter struct {
encryptionKey []byte
}
// KeyRingEncrypter derives an encryption key to encrypt all our files that are
// written to disk and returns an Encrypter object holding the key.
//
// The key itself, is the sha2 of a base key that we get from the keyring. We
// derive the key this way as we don't force the HSM (or any future
// abstractions) to be able to derive and know of the cipher that we'll use
// within our protocol.
func KeyRingEncrypter(keyRing keychain.KeyRing) (*Encrypter, error) {
// key = SHA256(baseKey)
baseKey, err := keyRing.DeriveKey(
baseEncryptionKeyLoc,
@ -47,33 +64,23 @@ func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) {
// TODO(roasbeef): throw back in ECDH?
return encryptionKey[:], nil
return &Encrypter{
encryptionKey: encryptionKey[:],
}, nil
}
// encryptPayloadToWriter attempts to write the set of bytes contained within
// the passed byes.Buffer into the passed io.Writer in an encrypted form. We
// use a 24-byte chachapoly AEAD instance with a randomized nonce that's
// pre-pended to the final payload and used as associated data in the AEAD. We
// use the passed keyRing to generate the encryption key, see genEncryptionKey
// for further details.
func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
keyRing keychain.KeyRing) error {
// First, we'll derive the key that we'll use to encrypt the payload
// for safe storage without giving away the details of any of our
// channels. The final operation is:
//
// key = SHA256(baseKey)
encryptionKey, err := genEncryptionKey(keyRing)
if err != nil {
return err
}
// EncryptPayloadToWriter attempts to write the set of provided bytes into the
// passed io.Writer in an encrypted form. We use a 24-byte chachapoly AEAD
// instance with a randomized nonce that's pre-pended to the final payload and
// used as associated data in the AEAD.
func (e Encrypter) EncryptPayloadToWriter(payload []byte,
w io.Writer) error {
// Before encryption, we'll initialize our cipher with the target
// encryption key, and also read out our random 24-byte nonce we use
// for encryption. Note that we use NewX, not New, as the latter
// version requires a 12-byte nonce, not a 24-byte nonce.
cipher, err := chacha20poly1305.NewX(encryptionKey)
cipher, err := chacha20poly1305.NewX(e.encryptionKey)
if err != nil {
return err
}
@ -84,7 +91,7 @@ func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
// Finally, we encrypted the final payload, and write out our
// ciphertext with nonce pre-pended.
ciphertext := cipher.Seal(nil, nonce[:], payload.Bytes(), nonce[:])
ciphertext := cipher.Seal(nil, nonce[:], payload, nonce[:])
if _, err := w.Write(nonce[:]); err != nil {
return err
@ -96,38 +103,30 @@ func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
return nil
}
// decryptPayloadFromReader attempts to decrypt the encrypted bytes within the
// DecryptPayloadFromReader attempts to decrypt the encrypted bytes within the
// passed io.Reader instance using the key derived from the passed keyRing. For
// further details regarding the key derivation protocol, see the
// genEncryptionKey method.
func decryptPayloadFromReader(payload io.Reader,
keyRing keychain.KeyRing) ([]byte, error) {
// First, we'll re-generate the encryption key that we use for all the
// SCBs.
encryptionKey, err := genEncryptionKey(keyRing)
if err != nil {
return nil, err
}
// KeyRingEncrypter function.
func (e Encrypter) DecryptPayloadFromReader(payload io.Reader) ([]byte,
error) {
// Next, we'll read out the entire blob as we need to isolate the nonce
// from the rest of the ciphertext.
packedBackup, err := ioutil.ReadAll(payload)
packedPayload, err := ioutil.ReadAll(payload)
if err != nil {
return nil, err
}
if len(packedBackup) < chacha20poly1305.NonceSizeX {
if len(packedPayload) < chacha20poly1305.NonceSizeX {
return nil, fmt.Errorf("payload size too small, must be at "+
"least %v bytes", chacha20poly1305.NonceSizeX)
}
nonce := packedBackup[:chacha20poly1305.NonceSizeX]
ciphertext := packedBackup[chacha20poly1305.NonceSizeX:]
nonce := packedPayload[:chacha20poly1305.NonceSizeX]
ciphertext := packedPayload[chacha20poly1305.NonceSizeX:]
// Now that we have the cipher text and the nonce separated, we can go
// ahead and decrypt the final blob so we can properly serialized the
// SCB.
cipher, err := chacha20poly1305.NewX(encryptionKey)
// ahead and decrypt the final blob so we can properly serialize.
cipher, err := chacha20poly1305.NewX(e.encryptionKey)
if err != nil {
return nil, err
}

View File

@ -1,41 +1,12 @@
package chanbackup
package lnencrypt
import (
"bytes"
"fmt"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/stretchr/testify/require"
)
var (
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
)
type mockKeyRing struct {
fail bool
}
func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
return keychain.KeyDescriptor{}, nil
}
func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
if m.fail {
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
}
_, pub := btcec.PrivKeyFromBytes(testWalletPrivKey)
return keychain.KeyDescriptor{
PubKey: pub,
}, nil
}
// TestEncryptDecryptPayload tests that given a static key, we're able to
// properly decrypt and encrypted payload. We also test that we'll reject a
// ciphertext that has been modified.
@ -81,15 +52,16 @@ func TestEncryptDecryptPayload(t *testing.T) {
},
}
keyRing := &mockKeyRing{}
keyRing := &MockKeyRing{}
for i, payloadCase := range payloadCases {
var cipherBuffer bytes.Buffer
encrypter, err := KeyRingEncrypter(keyRing)
require.NoError(t, err)
// First, we'll encrypt the passed payload with our scheme.
payloadReader := bytes.NewBuffer(payloadCase.plaintext)
err := encryptPayloadToWriter(
*payloadReader, &cipherBuffer, keyRing,
err = encrypter.EncryptPayloadToWriter(
payloadCase.plaintext, &cipherBuffer,
)
if err != nil {
t.Fatalf("unable encrypt paylaod: %v", err)
@ -107,7 +79,9 @@ func TestEncryptDecryptPayload(t *testing.T) {
cipherBuffer.Write(cipherText)
}
plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing)
plaintext, err := encrypter.DecryptPayloadFromReader(
&cipherBuffer,
)
switch {
// If this was meant to be a valid decryption, but we failed,
@ -131,26 +105,13 @@ func TestEncryptDecryptPayload(t *testing.T) {
}
}
// TestInvalidKeyEncryption tests that encryption fails if we're unable to
// obtain a valid key.
func TestInvalidKeyEncryption(t *testing.T) {
// TestInvalidKeyGeneration tests that key generation fails when deriving the
// key fails.
func TestInvalidKeyGeneration(t *testing.T) {
t.Parallel()
var b bytes.Buffer
err := encryptPayloadToWriter(b, &b, &mockKeyRing{true})
_, err := KeyRingEncrypter(&MockKeyRing{true})
if err == nil {
t.Fatalf("expected error due to fail key gen")
}
}
// TestInvalidKeyDecrytion tests that decryption fails if we're unable to
// obtain a valid key.
func TestInvalidKeyDecrytion(t *testing.T) {
t.Parallel()
var b bytes.Buffer
_, err := decryptPayloadFromReader(&b, &mockKeyRing{true})
if err == nil {
t.Fatalf("expected error due to fail key gen")
t.Fatal("expected error due to fail key gen")
}
}

40
lnencrypt/test_utils.go Normal file
View File

@ -0,0 +1,40 @@
package lnencrypt
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/keychain"
)
var (
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
)
type MockKeyRing struct {
Fail bool
}
func (m *MockKeyRing) DeriveNextKey(
keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
return keychain.KeyDescriptor{}, nil
}
func (m *MockKeyRing) DeriveKey(
keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
if m.Fail {
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
}
_, pub := btcec.PrivKeyFromBytes(testWalletPrivKey)
return keychain.KeyDescriptor{
PubKey: pub,
}, nil
}