diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 123689902..7f068e526 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -3,31 +3,20 @@ package lnwire import ( "bytes" crand "crypto/rand" - "encoding/binary" "encoding/hex" - "fmt" - "image/color" - "io" "math" "math/rand" "net" - "reflect" "testing" - "testing/quick" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" - "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn/v2" - "github.com/lightningnetwork/lnd/lnwallet/chainfee" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tor" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "pgregory.net/rapid" ) var ( @@ -54,106 +43,6 @@ var ( const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -func randLocalNonce(r *rand.Rand) Musig2Nonce { - var nonce Musig2Nonce - _, _ = io.ReadFull(r, nonce[:]) - - return nonce -} - -func someLocalNonce[T tlv.TlvType]( - r *rand.Rand) tlv.OptionalRecordT[T, Musig2Nonce] { - - return tlv.SomeRecordT(tlv.NewRecordT[T, Musig2Nonce]( - randLocalNonce(r), - )) -} - -func randPartialSig(r *rand.Rand) (*PartialSig, error) { - var sigBytes [32]byte - if _, err := r.Read(sigBytes[:]); err != nil { - return nil, fmt.Errorf("unable to generate sig: %w", err) - } - - var s btcec.ModNScalar - s.SetByteSlice(sigBytes[:]) - - return &PartialSig{ - Sig: s, - }, nil -} - -func somePartialSig(t *testing.T, - r *rand.Rand) tlv.OptionalRecordT[PartialSigType, PartialSig] { - - sig, err := randPartialSig(r) - if err != nil { - t.Fatal(err) - } - - return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig]( - *sig, - )) -} - -func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) { - var sigBytes [32]byte - if _, err := r.Read(sigBytes[:]); err != nil { - return nil, fmt.Errorf("unable to generate sig: %w", err) - } - - var s btcec.ModNScalar - s.SetByteSlice(sigBytes[:]) - - return &PartialSigWithNonce{ - PartialSig: NewPartialSig(s), - Nonce: randLocalNonce(r), - }, nil -} - -func somePartialSigWithNonce(t *testing.T, - r *rand.Rand) OptPartialSigWithNonceTLV { - - sig, err := randPartialSigWithNonce(r) - if err != nil { - t.Fatal(err) - } - - return tlv.SomeRecordT( - tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce]( - *sig, - ), - ) -} - -func randAlias(r *rand.Rand) NodeAlias { - var a NodeAlias - for i := range a { - a[i] = letterBytes[r.Intn(len(letterBytes))] - } - - return a -} - -func randPubKey() (*btcec.PublicKey, error) { - priv, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } - - return priv.PubKey(), nil -} - -// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. -func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { - pubKeyBytes, err := hex.DecodeString(keyHex) - if err != nil { - return nil, err - } - - return btcec.ParsePubKey(pubKeyBytes) -} - // generateRandomBytes returns a slice of n random bytes. func generateRandomBytes(n int) ([]byte, error) { b := make([]byte, n) @@ -176,140 +65,23 @@ func randRawKey(t *testing.T) [33]byte { return n } -func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { - // Generate size minimum one. Empty scripts should be tested specifically. - size := r.Intn(deliveryAddressMaxSize) + 1 - da := DeliveryAddress(make([]byte, size)) - - _, err := r.Read(da) - return da, err -} - -func randRawFeatureVector(r *rand.Rand) *RawFeatureVector { - featureVec := NewRawFeatureVector() - for i := 0; i < 10000; i++ { - if r.Int31n(2) == 0 { - featureVec.Set(FeatureBit(i)) - } - } - return featureVec -} - -func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) { - var ip [4]byte - if _, err := r.Read(ip[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - addrIP := net.IP(ip[:]) - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil -} - -func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) { - var ip [16]byte - if _, err := r.Read(ip[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - addrIP := net.IP(ip[:]) - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil -} - -func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { - var serviceID [tor.V2DecodedLen]byte - if _, err := r.Read(serviceID[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) - onionService += tor.OnionSuffix - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil -} - -func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { - var serviceID [tor.V3DecodedLen]byte - if _, err := r.Read(serviceID[:]); err != nil { - return nil, err - } - - var port [2]byte - if _, err := r.Read(port[:]); err != nil { - return nil, err - } - - onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) - onionService += tor.OnionSuffix - addrPort := int(binary.BigEndian.Uint16(port[:])) - - return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil -} - -func randOpaqueAddr(r *rand.Rand) (*OpaqueAddrs, error) { - payloadLen := r.Int63n(64) + 1 - payload := make([]byte, payloadLen) - - // The first byte is the address type. So set it to one that we - // definitely don't know about. - payload[0] = math.MaxUint8 - - // Generate random bytes for the rest of the payload. - if _, err := r.Read(payload[1:]); err != nil { - return nil, err - } - - return &OpaqueAddrs{Payload: payload}, nil -} - -func randAddrs(r *rand.Rand) ([]net.Addr, error) { - tcp4Addr, err := randTCP4Addr(r) +func randPubKey() (*btcec.PublicKey, error) { + priv, err := btcec.NewPrivateKey() if err != nil { return nil, err } - tcp6Addr, err := randTCP6Addr(r) + return priv.PubKey(), nil +} + +// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. +func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { + pubKeyBytes, err := hex.DecodeString(keyHex) if err != nil { return nil, err } - v2OnionAddr, err := randV2OnionAddr(r) - if err != nil { - return nil, err - } - - v3OnionAddr, err := randV3OnionAddr(r) - if err != nil { - return nil, err - } - - opaqueAddrs, err := randOpaqueAddr(r) - if err != nil { - return nil, err - } - - return []net.Addr{ - tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr, opaqueAddrs, - }, nil + return btcec.ParsePubKey(pubKeyBytes) } // TestChanUpdateChanFlags ensures that converting the ChanUpdateChanFlags and @@ -418,1558 +190,64 @@ func TestEmptyMessageUnknownType(t *testing.T) { } } -// randCustomRecords generates a random set of custom records for testing. -func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { - var ( - customRecords = CustomRecords{} - - // We'll generate a random number of records, between 1 and 10. - numRecords = r.Intn(9) + 1 - ) - - // For each record, we'll generate a random key and value. - for i := 0; i < numRecords; i++ { - // Keys must be equal to or greater than - // MinCustomRecordsTlvType. - keyOffset := uint64(r.Intn(100)) - key := MinCustomRecordsTlvType + keyOffset - - // Values are byte slices of any length. - value := make([]byte, r.Intn(10)) - _, err := r.Read(value) - require.NoError(t, err) - - customRecords[key] = value - } - - // Validate the custom records as a sanity check. - err := customRecords.Validate() - require.NoError(t, err) - - return customRecords -} - -// TestLightningWireProtocol uses the testing/quick package to create a series -// of fuzz tests to attempt to break a primary scenario which is implemented as -// property based testing scenario. +// TestLightningWireProtocol uses the rapid property-based testing framework to +// verify that all message types can be serialized and deserialized correctly. func TestLightningWireProtocol(t *testing.T) { t.Parallel() - // mainScenario is the primary test that will programmatically be - // executed for all registered wire messages. The quick-checker within - // testing/quick will attempt to find an input to this function, s.t - // the function returns false, if so then we've found an input that - // violates our model of the system. - mainScenario := func(msg Message) bool { - // Give a new message, we'll serialize the message into a new - // bytes buffer. - var b bytes.Buffer - if _, err := WriteMessage(&b, msg, 0); err != nil { - t.Fatalf("unable to write msg: %v", err) - return false + for msgType := MessageType(0); msgType < MsgEnd; msgType++ { + // If MakeEmptyMessage returns an error, then this isn't yet a + // used message type. + if _, err := MakeEmptyMessage(msgType); err != nil { + continue } - // Next, we'll ensure that the serialized payload (subtracting - // the 2 bytes for the message type) is _below_ the specified - // max payload size for this message. - payloadLen := uint32(b.Len()) - 2 - if payloadLen > MaxMsgBody { - t.Fatalf("msg payload constraint violated: %v > %v", - payloadLen, MaxMsgBody) - return false - } + t.Run(msgType.String(), rapid.MakeCheck(func(t *rapid.T) { + // Create an empty message of the given type. + m, err := MakeEmptyMessage(msgType) - // Finally, we'll deserialize the message from the written - // buffer, and finally assert that the messages are equal. - newMsg, err := ReadMessage(&b, 0) - if err != nil { - t.Fatalf("unable to read msg: %v", err) - return false - } - if !assert.Equalf(t, msg, newMsg, "message mismatch") { - return false - } - - return true - } - - // customTypeGen is a map of functions that are able to randomly - // generate a given type. These functions are needed for types which - // are too complex for the testing/quick package to automatically - // generate. - customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){ - MsgStfu: func(v []reflect.Value, r *rand.Rand) { - req := Stfu{} - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate ChanID: %v", err) + // An error means this isn't a valid message type, so we + // skip it. + if err != nil { + return } - // 1/2 chance of being initiator - req.Initiator = r.Intn(2) == 1 - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgInit: func(v []reflect.Value, r *rand.Rand) { - req := NewInitMessage( - randRawFeatureVector(r), - randRawFeatureVector(r), + // The message must support the message type interface. + testMsg, ok := m.(TestMessage) + require.True( + t, ok, "message %v doesn't support TestMessage", + msgType, ) - v[0] = reflect.ValueOf(*req) - }, - MsgOpenChannel: func(v []reflect.Value, r *rand.Rand) { - req := OpenChannel{ - FundingAmount: btcutil.Amount(r.Int63()), - PushAmount: MilliSatoshi(r.Int63()), - DustLimit: btcutil.Amount(r.Int63()), - MaxValueInFlight: MilliSatoshi(r.Int63()), - ChannelReserve: btcutil.Amount(r.Int63()), - HtlcMinimum: MilliSatoshi(r.Int31()), - FeePerKiloWeight: uint32(r.Int63()), - CsvDelay: uint16(r.Int31()), - MaxAcceptedHTLCs: uint16(r.Int31()), - ChannelFlags: FundingFlag(uint8(r.Int31())), - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - var err error - req.FundingKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.RevocationPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.PaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.DelayedPaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.HtlcPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.FirstCommitmentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 1/2 chance empty TLV records. - if r.Intn(2) == 0 { - req.UpfrontShutdownScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery address: %v", err) - return - } - - req.ChannelType = new(ChannelType) - *req.ChannelType = ChannelType(*randRawFeatureVector(r)) - - req.LeaseExpiry = new(LeaseExpiry) - *req.LeaseExpiry = LeaseExpiry(1337) - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } else { - req.UpfrontShutdownScript = []byte{} - } - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgAcceptChannel: func(v []reflect.Value, r *rand.Rand) { - req := AcceptChannel{ - DustLimit: btcutil.Amount(r.Int63()), - MaxValueInFlight: MilliSatoshi(r.Int63()), - ChannelReserve: btcutil.Amount(r.Int63()), - MinAcceptDepth: uint32(r.Int31()), - HtlcMinimum: MilliSatoshi(r.Int31()), - CsvDelay: uint16(r.Int31()), - MaxAcceptedHTLCs: uint16(r.Int31()), - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - var err error - req.FundingKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.RevocationPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.PaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.DelayedPaymentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.HtlcPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.FirstCommitmentPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 1/2 chance empty TLV records. - if r.Intn(2) == 0 { - req.UpfrontShutdownScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery address: %v", err) - return - } - - req.ChannelType = new(ChannelType) - *req.ChannelType = ChannelType(*randRawFeatureVector(r)) - - req.LeaseExpiry = new(LeaseExpiry) - *req.LeaseExpiry = LeaseExpiry(1337) - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } else { - req.UpfrontShutdownScript = []byte{} - } - - // 1/2 chance additional TLV data. - if r.Intn(2) == 0 { - req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} - } - - v[0] = reflect.ValueOf(req) - }, - MsgFundingCreated: func(v []reflect.Value, r *rand.Rand) { - req := FundingCreated{ - ExtraData: make([]byte, 0), - } - - if _, err := r.Read(req.PendingChannelID[:]); err != nil { - t.Fatalf("unable to generate pending chan id: %v", err) - return - } - - if _, err := r.Read(req.FundingPoint.Hash[:]); err != nil { - t.Fatalf("unable to generate hash: %v", err) - return - } - req.FundingPoint.Index = uint32(r.Int31()) % math.MaxUint16 - - var err error - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // 1/2 chance to attach a partial sig. - if r.Intn(2) == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgFundingSigned: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - req := FundingSigned{ - ChanID: ChannelID(c), - ExtraData: make([]byte, 0), - } - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // 1/2 chance to attach a partial sig. - if r.Intn(2) == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelReady: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - require.NoError(t, err) - - pubKey, err := randPubKey() - require.NoError(t, err) - - req := NewChannelReady(c, pubKey) - - if r.Int31()%2 == 0 { - scid := NewShortChanIDFromInt(uint64(r.Int63())) - req.AliasScid = &scid - - //nolint:ll - req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - if r.Int31()%2 == 0 { - nodeNonce := tlv.ZeroRecordT[ - tlv.TlvType0, Musig2Nonce, - ]() - nodeNonce.Val = randLocalNonce(r) - req.AnnouncementNodeNonce = tlv.SomeRecordT( - nodeNonce, - ) - - btcNonce := tlv.ZeroRecordT[ - tlv.TlvType2, Musig2Nonce, - ]() - btcNonce.Val = randLocalNonce(r) - req.AnnouncementBitcoinNonce = tlv.SomeRecordT( - btcNonce, - ) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgShutdown: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - shutdownAddr, err := randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - req := Shutdown{ - ChannelID: ChannelID(c), - Address: shutdownAddr, - } - - if r.Int31()%2 == 0 { - //nolint:ll - req.ShutdownNonce = someLocalNonce[ShutdownNonceType](r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingSigned: func(v []reflect.Value, r *rand.Rand) { - req := ClosingSigned{ - FeeSatoshis: btcutil.Amount(r.Int63()), - ExtraData: make([]byte, 0), - } - var err error - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChannelID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - if r.Int31()%2 == 0 { - req.PartialSig = somePartialSig(t, r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgDynPropose: func(v []reflect.Value, r *rand.Rand) { - var dp DynPropose - rand.Read(dp.ChanID[:]) - - if rand.Uint32()%2 == 0 { - v := btcutil.Amount(rand.Uint32()) - dp.DustLimit = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := MilliSatoshi(rand.Uint32()) - dp.MaxValueInFlight = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := btcutil.Amount(rand.Uint32()) - dp.ChannelReserve = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := uint16(rand.Uint32()) - dp.CsvDelay = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := uint16(rand.Uint32()) - dp.MaxAcceptedHTLCs = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v, _ := btcec.NewPrivateKey() - dp.FundingKey = fn.Some(*v.PubKey()) - } - - if rand.Uint32()%2 == 0 { - v := ChannelType(*NewRawFeatureVector()) - dp.ChannelType = fn.Some(v) - } - - if rand.Uint32()%2 == 0 { - v := chainfee.SatPerKWeight(rand.Uint32()) - dp.KickoffFeerate = fn.Some(v) - } - - v[0] = reflect.ValueOf(dp) - }, - MsgDynReject: func(v []reflect.Value, r *rand.Rand) { - var dr DynReject - rand.Read(dr.ChanID[:]) - - features := NewRawFeatureVector() - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPDustLimitSatoshis)) - } - - if rand.Uint32()%2 == 0 { - features.Set( - FeatureBit(DPMaxHtlcValueInFlightMsat), - ) - } - - if rand.Uint32()%2 == 0 { - features.Set( - FeatureBit(DPChannelReserveSatoshis), - ) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPToSelfDelay)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPMaxAcceptedHtlcs)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPFundingPubkey)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPChannelType)) - } - - if rand.Uint32()%2 == 0 { - features.Set(FeatureBit(DPKickoffFeerate)) - } - dr.UpdateRejections = *features - - v[0] = reflect.ValueOf(dr) - }, - MsgDynAck: func(v []reflect.Value, r *rand.Rand) { - var da DynAck - - rand.Read(da.ChanID[:]) - if rand.Uint32()%2 == 0 { - var nonce Musig2Nonce - rand.Read(nonce[:]) - da.LocalNonce = fn.Some(nonce) - } - - v[0] = reflect.ValueOf(da) - }, - MsgKickoffSig: func(v []reflect.Value, r *rand.Rand) { - ks := KickoffSig{ - ExtraData: make([]byte, 0), - } - - rand.Read(ks.ChanID[:]) - rand.Read(ks.Signature.bytes[:]) - - v[0] = reflect.ValueOf(ks) - }, - MsgCommitSig: func(v []reflect.Value, r *rand.Rand) { - req := NewCommitSig() - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - var err error - req.CommitSig, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - // Only create the slice if there will be any signatures - // in it to prevent false positive test failures due to - // an empty slice versus a nil slice. - numSigs := uint16(r.Int31n(500)) - if numSigs > 0 { - req.HtlcSigs = make([]Sig, numSigs) - } - for i := 0; i < int(numSigs); i++ { - req.HtlcSigs[i], err = NewSigFromSignature( - testSig, - ) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - } - - req.CustomRecords = randCustomRecords(t, r) - - // 50/50 chance to attach a partial sig. - if r.Int31()%2 == 0 { - req.PartialSig = somePartialSigWithNonce(t, r) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgRevokeAndAck: func(v []reflect.Value, r *rand.Rand) { - req := NewRevokeAndAck() - if _, err := r.Read(req.ChanID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - if _, err := r.Read(req.Revocation[:]); err != nil { - t.Fatalf("unable to generate bytes: %v", err) - return - } - var err error - req.NextRevocationKey, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - // 50/50 chance to attach a local nonce. - if r.Int31()%2 == 0 { - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var err error - req := ChannelAnnouncement1{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - NodeID1: randRawKey(t), - NodeID2: randRawKey(t), - BitcoinKey1: randRawKey(t), - BitcoinKey2: randRawKey(t), - Features: randRawFeatureVector(r), - ExtraOpaqueData: make([]byte, 0), - } - req.NodeSig1, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.NodeSig2, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.BitcoinSig1, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - req.BitcoinSig2, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgNodeAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var err error - req := NodeAnnouncement{ - NodeID: randRawKey(t), - Features: randRawFeatureVector(r), - Timestamp: uint32(r.Int31()), - Alias: randAlias(r), - RGBColor: color.RGBA{ - R: uint8(r.Int31()), - G: uint8(r.Int31()), - B: uint8(r.Int31()), - }, - ExtraOpaqueData: make([]byte, 0), - } - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - req.Addresses, err = randAddrs(r) - if err != nil { - t.Fatalf("unable to generate addresses: %v", err) - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelUpdate: func(v []reflect.Value, r *rand.Rand) { - var err error - - msgFlags := ChanUpdateMsgFlags(r.Int31()) - maxHtlc := MilliSatoshi(r.Int63()) - - // We make the max_htlc field zero if it is not flagged - // as being part of the ChannelUpdate, to pass - // serialization tests, as it will be ignored if the bit - // is not set. - if msgFlags&ChanUpdateRequiredMaxHtlc == 0 { - maxHtlc = 0 - } - - req := ChannelUpdate1{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - Timestamp: uint32(r.Int31()), - MessageFlags: msgFlags, - ChannelFlags: ChanUpdateChanFlags(r.Int31()), - TimeLockDelta: uint16(r.Int31()), - HtlcMinimumMsat: MilliSatoshi(r.Int63()), - HtlcMaximumMsat: maxHtlc, - BaseFee: uint32(r.Int31()), - FeeRate: uint32(r.Int31()), - ExtraOpaqueData: make([]byte, 0), - } - req.Signature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to generate chain hash: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { - var err error - req := AnnounceSignatures1{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), - ExtraOpaqueData: make([]byte, 0), - } - - req.NodeSignature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - req.BitcoinSignature, err = NewSigFromSignature(testSig) - if err != nil { - t.Fatalf("unable to parse sig: %v", err) - return - } - - if _, err := r.Read(req.ChannelID[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make([]byte, numExtraBytes) - _, err := r.Read(req.ExtraOpaqueData[:]) - if err != nil { - t.Fatalf("unable to generate opaque "+ - "bytes: %v", err) - return - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelReestablish: func(v []reflect.Value, r *rand.Rand) { - req := ChannelReestablish{ - NextLocalCommitHeight: uint64(r.Int63()), - RemoteCommitTailHeight: uint64(r.Int63()), - ExtraData: make([]byte, 0), - } - - // With a 50/50 probability, we'll include the - // additional fields so we can test our ability to - // properly parse, and write out the optional fields. - if r.Int()%2 == 0 { - _, err := r.Read(req.LastRemoteCommitSecret[:]) - if err != nil { - t.Fatalf("unable to read commit secret: %v", err) - return - } - - req.LocalUnrevokedCommitPoint, err = randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - - //nolint:ll - req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) - } - - v[0] = reflect.ValueOf(req) - }, - MsgGossipTimestampRange: func(v []reflect.Value, r *rand.Rand) { - req := GossipTimestampRange{ - FirstTimestamp: rand.Uint32(), - TimestampRange: rand.Uint32(), - ExtraData: make([]byte, 0), - } - - _, err := rand.Read(req.ChainHash[:]) - require.NoError(t, err) - - // Sometimes add a block range. - if r.Int31()%2 == 0 { - firstBlock := tlv.ZeroRecordT[ - tlv.TlvType2, uint32, - ]() - firstBlock.Val = rand.Uint32() - req.FirstBlockHeight = tlv.SomeRecordT( - firstBlock, - ) - - blockRange := tlv.ZeroRecordT[ - tlv.TlvType4, uint32, - ]() - blockRange.Val = rand.Uint32() - req.BlockRange = tlv.SomeRecordT(blockRange) - } - - v[0] = reflect.ValueOf(req) - }, - MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{ - ExtraData: make([]byte, 0), - } - - // With a 50/50 change, we'll either use zlib encoding, - // or regular encoding. - if r.Int31()%2 == 0 { - req.EncodingType = EncodingSortedZlib - } else { - req.EncodingType = EncodingSortedPlain - } - - if _, err := rand.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to read chain hash: %v", err) - return - } - - numChanIDs := rand.Int31n(5000) - for i := int32(0); i < numChanIDs; i++ { - req.ShortChanIDs = append(req.ShortChanIDs, - NewShortChanIDFromInt(uint64(r.Int63()))) - } - - v[0] = reflect.ValueOf(req) - }, - MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { - req := ReplyChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - ExtraData: make([]byte, 0), - } - - if _, err := rand.Read(req.ChainHash[:]); err != nil { - t.Fatalf("unable to read chain hash: %v", err) - return - } - - req.Complete = uint8(r.Int31n(2)) - - // With a 50/50 change, we'll either use zlib encoding, - // or regular encoding. - if r.Int31()%2 == 0 { - req.EncodingType = EncodingSortedZlib - } else { - req.EncodingType = EncodingSortedPlain - } - - numChanIDs := rand.Int31n(4000) - for i := int32(0); i < numChanIDs; i++ { - req.ShortChanIDs = append(req.ShortChanIDs, - NewShortChanIDFromInt(uint64(r.Int63()))) - } - - // With a 50/50 chance, add some timestamps. - if r.Int31()%2 == 0 { - for i := int32(0); i < numChanIDs; i++ { - timestamps := ChanUpdateTimestamps{ - Timestamp1: rand.Uint32(), - Timestamp2: rand.Uint32(), - } - req.Timestamps = append( - req.Timestamps, timestamps, - ) - } - } - - v[0] = reflect.ValueOf(req) - }, - MsgQueryChannelRange: func(v []reflect.Value, r *rand.Rand) { - req := QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - ExtraData: make([]byte, 0), - } - - _, err := rand.Read(req.ChainHash[:]) - require.NoError(t, err) - - // With a 50/50 change, we'll set a query option. - if r.Int31()%2 == 0 { - req.QueryOptions = NewTimestampQueryOption() - } - - v[0] = reflect.ValueOf(req) - }, - MsgPing: func(v []reflect.Value, r *rand.Rand) { - // We use a special message generator here to ensure we - // don't generate ping messages that are too large, - // which'll cause the test to fail. - // - // We'll allow the test to generate padding bytes up to - // the max message limit, factoring in the 2 bytes for - // the num pong bytes and 2 bytes for encoding the - // length of the padding bytes. - paddingBytes := make([]byte, rand.Intn(MaxMsgBody-3)) - req := Ping{ - NumPongBytes: uint16(r.Intn(MaxPongBytes + 1)), - PaddingBytes: paddingBytes, - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingComplete: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", - err) - return - } - - req := ClosingComplete{ - ChannelID: ChannelID(c), - FeeSatoshis: btcutil.Amount(r.Int63()), - LockTime: uint32(r.Int63()), - ClosingSigs: ClosingSigs{}, - } - req.CloserScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - req.CloseeScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - if r.Intn(2) == 0 { - sig := req.CloserNoClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserNoClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.NoCloserClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.NoCloserClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.CloserAndClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserAndClosee = tlv.SomeRecordT(sig) - } - - v[0] = reflect.ValueOf(req) - }, - MsgClosingSig: func(v []reflect.Value, r *rand.Rand) { - var c [32]byte - _, err := r.Read(c[:]) - if err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } - - req := ClosingSig{ - ChannelID: ChannelID(c), - ClosingSigs: ClosingSigs{}, - FeeSatoshis: btcutil.Amount(r.Int63()), - LockTime: uint32(r.Int63()), - } - req.CloserScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - req.CloseeScript, err = randDeliveryAddress(r) - if err != nil { - t.Fatalf("unable to generate delivery "+ - "address: %v", err) - return - } - - if r.Intn(2) == 0 { - sig := req.CloserNoClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserNoClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.NoCloserClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.NoCloserClosee = tlv.SomeRecordT(sig) - } - if r.Intn(2) == 0 { - sig := req.CloserAndClosee.Zero() - _, err := r.Read(sig.Val.bytes[:]) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } - - req.CloserAndClosee = tlv.SomeRecordT(sig) - } - - v[0] = reflect.ValueOf(req) - }, - MsgUpdateAddHTLC: func(v []reflect.Value, r *rand.Rand) { - req := &UpdateAddHTLC{ - ID: r.Uint64(), - Amount: MilliSatoshi(r.Uint64()), - Expiry: r.Uint32(), - } - - _, err := r.Read(req.ChanID[:]) - require.NoError(t, err) - - _, err = r.Read(req.PaymentHash[:]) - require.NoError(t, err) - - _, err = r.Read(req.OnionBlob[:]) - require.NoError(t, err) - - req.CustomRecords = randCustomRecords(t, r) - - // Generate a blinding point 50% of the time, since not - // all update adds will use route blinding. - if r.Int31()%2 == 0 { - pubkey, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", - err) - - return - } - - req.BlindingPoint = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType0]( - pubkey, - ), - ) - } - - v[0] = reflect.ValueOf(*req) - }, - MsgUpdateFulfillHTLC: func(v []reflect.Value, r *rand.Rand) { - req := &UpdateFulfillHTLC{ - ID: r.Uint64(), - } - - _, err := r.Read(req.ChanID[:]) - require.NoError(t, err) - - _, err = r.Read(req.PaymentPreimage[:]) - require.NoError(t, err) - - req.CustomRecords = randCustomRecords(t, r) - - // Generate some random TLV records 50% of the time. - if r.Int31()%2 == 0 { - req.ExtraData = []byte{ - 0x01, 0x03, 1, 2, 3, - 0x02, 0x03, 4, 5, 6, - } - } - - v[0] = reflect.ValueOf(*req) - }, - MsgAnnounceSignatures2: func(v []reflect.Value, - r *rand.Rand) { - - req := AnnounceSignatures2{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - ExtraOpaqueData: make([]byte, 0), - } - - _, err := r.Read(req.ChannelID[:]) - require.NoError(t, err) - - partialSig, err := randPartialSig(r) - require.NoError(t, err) - - req.PartialSignature = *partialSig - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelAnnouncement2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelAnnouncement2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } - - req.ShortChannelID.Val = NewShortChanIDFromInt( - uint64(r.Int63()), + // Use the RandTestMessage method to create a randomized + // message. + msg := testMsg.RandTestMessage(t) + + // Serialize the message to a buffer. + var b bytes.Buffer + writtenBytes, err := WriteMessage(&b, msg, 0) + require.NoError(t, err, "unable to write msg") + + // Check that the serialized payload is below the max + // payload size, accounting for the message type size. + payloadLen := uint32(writtenBytes) - 2 + require.LessOrEqual( + t, payloadLen, uint32(MaxMsgBody), + "msg payload constraint violated: %v > %v", + payloadLen, MaxMsgBody, ) - req.Capacity.Val = rand.Uint64() - req.Features.Val = *randRawFeatureVector(r) + // Deserialize the message from the buffer. + newMsg, err := ReadMessage(&b, 0) + require.NoError(t, err, "unable to read msg") - req.NodeID1.Val = randRawKey(t) - req.NodeID2.Val = randRawKey(t) - - // Sometimes set chain hash to bitcoin mainnet genesis - // hash. - req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if r.Int31()%2 == 0 { - _, err := r.Read(req.ChainHash.Val[:]) - require.NoError(t, err) - } - - // Sometimes set the bitcoin keys. - if r.Int31()%2 == 0 { - btcKey1 := tlv.ZeroRecordT[ - tlv.TlvType12, [33]byte, - ]() - btcKey1.Val = randRawKey(t) - req.BitcoinKey1 = tlv.SomeRecordT(btcKey1) - - btcKey2 := tlv.ZeroRecordT[ - tlv.TlvType14, [33]byte, - ]() - btcKey2.Val = randRawKey(t) - req.BitcoinKey2 = tlv.SomeRecordT(btcKey2) - - // Occasionally also set the merkle root hash. - if r.Int31()%2 == 0 { - hash := tlv.ZeroRecordT[ - tlv.TlvType16, [32]byte, - ]() - - _, err := r.Read(hash.Val[:]) - require.NoError(t, err) - - req.MerkleRootHash = tlv.SomeRecordT( - hash, - ) - } - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, - MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelUpdate2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } - - req.ShortChannelID.Val = NewShortChanIDFromInt( - uint64(r.Int63()), + // Verify the deserialized message matches the original. + require.Equal( + t, msg, newMsg, + "message mismatch for type %s", msgType, ) - req.BlockHeight.Val = r.Uint32() - req.HTLCMaximumMsat.Val = MilliSatoshi(r.Uint64()) - - // Sometimes set chain hash to bitcoin mainnet genesis - // hash. - req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if r.Int31()%2 == 0 { - _, err := r.Read(req.ChainHash.Val[:]) - require.NoError(t, err) - } - - // Sometimes use default htlc min msat. - req.HTLCMinimumMsat.Val = defaultHtlcMinMsat - if r.Int31()%2 == 0 { - req.HTLCMinimumMsat.Val = MilliSatoshi( - r.Uint64(), - ) - } - - // Sometimes set the cltv expiry delta to the default. - req.CLTVExpiryDelta.Val = defaultCltvExpiryDelta - if r.Int31()%2 == 0 { - req.CLTVExpiryDelta.Val = uint16(r.Int31()) - } - - // Sometimes use default fee base. - req.FeeBaseMsat.Val = defaultFeeBaseMsat - if r.Int31()%2 == 0 { - req.FeeBaseMsat.Val = r.Uint32() - } - - // Sometimes use default proportional fee. - req.FeeProportionalMillionths.Val = - defaultFeeProportionalMillionths - if r.Int31()%2 == 0 { - req.FeeProportionalMillionths.Val = r.Uint32() - } - - // Alternate between the two direction possibilities. - if r.Int31()%2 == 0 { - req.SecondPeer = tlv.SomeRecordT( - tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean](), //nolint:ll - ) - } - - // Sometimes set the incoming disabled flag. - if r.Int31()%2 == 0 { - req.DisabledFlags.Val |= - ChanUpdateDisableIncoming - } - - // Sometimes set the outgoing disabled flag. - if r.Int31()%2 == 0 { - req.DisabledFlags.Val |= - ChanUpdateDisableOutgoing - } - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - - v[0] = reflect.ValueOf(req) - }, + })) } - - // With the above types defined, we'll now generate a slice of - // scenarios to feed into quick.Check. The function scans in input - // space of the target function under test, so we'll need to create a - // series of wrapper functions to force it to iterate over the target - // types, but re-use the mainScenario defined above. - tests := []struct { - msgType MessageType - scenario interface{} - }{ - { - msgType: MsgStfu, - scenario: func(m Stfu) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgInit, - scenario: func(m Init) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgWarning, - scenario: func(m Warning) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgError, - scenario: func(m Error) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgPing, - scenario: func(m Ping) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgPong, - scenario: func(m Pong) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgOpenChannel, - scenario: func(m OpenChannel) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAcceptChannel, - scenario: func(m AcceptChannel) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgFundingCreated, - scenario: func(m FundingCreated) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgFundingSigned, - scenario: func(m FundingSigned) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelReady, - scenario: func(m ChannelReady) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgShutdown, - scenario: func(m Shutdown) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingSigned, - scenario: func(m ClosingSigned) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynPropose, - scenario: func(m DynPropose) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynReject, - scenario: func(m DynReject) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgDynAck, - scenario: func(m DynAck) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgKickoffSig, - scenario: func(m KickoffSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateAddHTLC, - scenario: func(m UpdateAddHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFulfillHTLC, - scenario: func(m UpdateFulfillHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFailHTLC, - scenario: func(m UpdateFailHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgCommitSig, - scenario: func(m CommitSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgRevokeAndAck, - scenario: func(m RevokeAndAck) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgUpdateFee, - scenario: func(m UpdateFee) bool { - return mainScenario(&m) - }, - }, - { - - msgType: MsgUpdateFailMalformedHTLC, - scenario: func(m UpdateFailMalformedHTLC) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelReestablish, - scenario: func(m ChannelReestablish) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelAnnouncement, - scenario: func(m ChannelAnnouncement1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgNodeAnnouncement, - scenario: func(m NodeAnnouncement) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelUpdate, - scenario: func(m ChannelUpdate1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAnnounceSignatures, - scenario: func(m AnnounceSignatures1) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgGossipTimestampRange, - scenario: func(m GossipTimestampRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgQueryShortChanIDs, - scenario: func(m QueryShortChanIDs) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgReplyShortChanIDsEnd, - scenario: func(m ReplyShortChanIDsEnd) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgQueryChannelRange, - scenario: func(m QueryChannelRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgReplyChannelRange, - scenario: func(m ReplyChannelRange) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingComplete, - scenario: func(m ClosingComplete) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgClosingSig, - scenario: func(m ClosingSig) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgAnnounceSignatures2, - scenario: func(m AnnounceSignatures2) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelAnnouncement2, - scenario: func(m ChannelAnnouncement2) bool { - return mainScenario(&m) - }, - }, - { - msgType: MsgChannelUpdate2, - scenario: func(m ChannelUpdate2) bool { - return mainScenario(&m) - }, - }, - } - for _, test := range tests { - t.Run(test.msgType.String(), func(t *testing.T) { - var config *quick.Config - - // If the type defined is within the custom type gen - // map above, then we'll modify the default config to - // use this Value function that knows how to generate - // the proper types. - if valueGen, ok := customTypeGen[test.msgType]; ok { - config = &quick.Config{ - Values: valueGen, - } - } - - t.Logf("Running fuzz tests for msgType=%v", - test.msgType) - - err := quick.Check(test.scenario, config) - if err != nil { - t.Fatalf("fuzz checks for msg=%v failed: %v", - test.msgType, err) - } - }) - } - } func init() {