diff --git a/lnwire/dyn_commit.go b/lnwire/dyn_commit.go index 73b992f61..083da1626 100644 --- a/lnwire/dyn_commit.go +++ b/lnwire/dyn_commit.go @@ -25,10 +25,14 @@ type DynCommit struct { ExtraData ExtraOpaqueData } -// A compile time check to ensure DynAck implements the lnwire.Message +// A compile time check to ensure DynCommit implements the lnwire.Message // interface. var _ Message = (*DynCommit)(nil) +// A compile time check to ensure DynCommit implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*DynCommit)(nil) + // Encode serializes the target DynAck into the passed io.Writer. Serialization // will observe the rules defined by the passed protocol version. // @@ -133,3 +137,10 @@ func (dc *DynCommit) Decode(r io.Reader, _ uint32) error { func (dc *DynCommit) MsgType() MessageType { return MsgDynCommit } + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (dc *DynCommit) SerializedSize() (uint32, error) { + return MessageSerializedSize(dc) +} diff --git a/lnwire/test_message.go b/lnwire/test_message.go index b29fa67b6..ec04ae163 100644 --- a/lnwire/test_message.go +++ b/lnwire/test_message.go @@ -909,6 +909,91 @@ func (dr *DynReject) RandTestMessage(t *rapid.T) Message { } } +// A compile time check to ensure DynCommit implements the lnwire.TestMessage +// interface. +var _ TestMessage = (*DynCommit)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (dc *DynCommit) RandTestMessage(t *rapid.T) Message { + chanID := RandChannelID(t) + + da := &DynAck{ + ChanID: chanID, + } + + dp := &DynPropose{ + ChanID: chanID, + } + + // Randomly decide which optional fields to include + includeDustLimit := rapid.Bool().Draw(t, "includeDustLimit") + includeMaxValueInFlight := rapid.Bool().Draw( + t, "includeMaxValueInFlight", + ) + includeChannelReserve := rapid.Bool().Draw(t, "includeChannelReserve") + includeCsvDelay := rapid.Bool().Draw(t, "includeCsvDelay") + includeMaxAcceptedHTLCs := rapid.Bool().Draw( + t, "includeMaxAcceptedHTLCs", + ) + includeChannelType := rapid.Bool().Draw(t, "includeChannelType") + + // Generate random values for each included field + if includeDustLimit { + var rec tlv.RecordT[tlv.TlvType0, btcutil.Amount] + val := btcutil.Amount(rapid.Uint32().Draw(t, "dustLimit")) + rec.Val = val + dp.DustLimit = tlv.SomeRecordT(rec) + } + + if includeMaxValueInFlight { + var rec tlv.RecordT[tlv.TlvType2, MilliSatoshi] + val := MilliSatoshi(rapid.Uint64().Draw(t, "maxValueInFlight")) + rec.Val = val + dp.MaxValueInFlight = tlv.SomeRecordT(rec) + } + + if includeChannelReserve { + var rec tlv.RecordT[tlv.TlvType6, btcutil.Amount] + val := btcutil.Amount(rapid.Uint32().Draw(t, "channelReserve")) + rec.Val = val + dp.ChannelReserve = tlv.SomeRecordT(rec) + } + + if includeCsvDelay { + csvDelay := dp.CsvDelay.Zero() + val := rapid.Uint16().Draw(t, "csvDelay") + csvDelay.Val = val + dp.CsvDelay = tlv.SomeRecordT(csvDelay) + } + + if includeMaxAcceptedHTLCs { + maxHtlcs := dp.MaxAcceptedHTLCs.Zero() + maxHtlcs.Val = rapid.Uint16().Draw(t, "maxAcceptedHTLCs") + dp.MaxAcceptedHTLCs = tlv.SomeRecordT(maxHtlcs) + } + + if includeChannelType { + chanType := dp.ChannelType.Zero() + chanType.Val = *RandChannelType(t) + dp.ChannelType = tlv.SomeRecordT(chanType) + } + + var extraData ExtraOpaqueData + randData := RandExtraOpaqueData(t, nil) + if len(randData) > 0 { + extraData = randData + } + + return &DynCommit{ + DynPropose: *dp, + DynAck: *da, + ExtraData: extraData, + } +} + // A compile time check to ensure FundingCreated implements the TestMessage // interface. var _ TestMessage = (*FundingCreated)(nil)