diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 9eac221dd..b7bc9b34a 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -419,12 +419,16 @@ func (o *mockObfuscator) Reextract( return nil } +var fakeHmac = []byte("hmachmachmachmachmachmachmachmac") + func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( lnwire.OpaqueReason, error) { o.failure = failure var b bytes.Buffer + b.Write(fakeHmac) + if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { return nil, err } @@ -436,7 +440,12 @@ func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire. } func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason { - return reason + var b bytes.Buffer + b.Write(fakeHmac) + + b.Write(reason) + + return b.Bytes() } // mockDeobfuscator mock implementation of the failure deobfuscator which @@ -447,7 +456,13 @@ func newMockDeobfuscator() ErrorDecrypter { return &mockDeobfuscator{} } -func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) (*ForwardingError, error) { +func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) ( + *ForwardingError, error) { + + if !bytes.Equal(reason[:32], fakeHmac) { + return nil, errors.New("fake decryption error") + } + reason = reason[32:] r := bytes.NewReader(reason) failure, err := lnwire.DecodeFailure(r, 0) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 461d8875a..4d666b760 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1,7 +1,6 @@ package htlcswitch import ( - "bytes" "crypto/rand" "crypto/sha256" "fmt" @@ -23,7 +22,6 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/ticker" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -4088,10 +4086,10 @@ func TestSwitchHoldForward(t *testing.T) { expectedFailure := &lnwire.FailInvalidOnionKey{ OnionSHA256: shaOnionBlob, } - var b bytes.Buffer - require.NoError(t, lnwire.EncodeFailure(&b, expectedFailure, 0)) - assert.Equal(t, lnwire.OpaqueReason(b.Bytes()), failPacket.Reason) + fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason) + require.NoError(t, err) + require.Equal(t, expectedFailure, fwdErr.WireMessage()) assertNumCircuits(t, c.s, 0, 0) @@ -5515,10 +5513,13 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC) require.True(t, ok) - r := bytes.NewReader(failHtlc.Reason) - failure, err := lnwire.DecodeFailure(r, 0) + fwdErr, err := newMockDeobfuscator().DecryptError( + failHtlc.Reason, + ) require.NoError(t, err) + failure := fwdErr.WireMessage() + failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok)