lnwire: let DNSAddress implement RecordProducer

In preparation for using this type as a TLV record, we let it implement
the RecordProducer interface.
This commit is contained in:
Elle Mouton
2025-09-22 12:09:46 +02:00
parent e9a4f22dd6
commit e24dd2f9e0
3 changed files with 204 additions and 0 deletions

View File

@@ -1,11 +1,14 @@
package lnwire
import (
"bytes"
"fmt"
"strings"
"testing"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
)
// TestValidateDNSAddr tests hostname and port validation per BOLT #7.
@@ -85,3 +88,112 @@ func TestValidateDNSAddr(t *testing.T) {
})
}
}
// TestDNSAddressTLVEncoding tests the TLV encoding and decoding of DNSAddress
// structs using the ExtraOpaqueData interface.
func TestDNSAddressTLVEncoding(t *testing.T) {
t.Parallel()
testDNSAddr := DNSAddress{
Hostname: "lightning.example.com",
Port: 9000,
}
var extraData ExtraOpaqueData
require.NoError(t, extraData.PackRecords(&testDNSAddr))
var decodedDNSAddr DNSAddress
tlvs, err := extraData.ExtractRecords(&decodedDNSAddr)
require.NoError(t, err)
require.Contains(t, tlvs, tlv.Type(0))
require.Equal(t, testDNSAddr, decodedDNSAddr)
}
// TestDNSAddressRecord tests the TLV Record interface of DNSAddress
// by directly encoding and decoding using the Record method.
func TestDNSAddressRecord(t *testing.T) {
t.Parallel()
testDNSAddr := DNSAddress{
Hostname: "lightning.example.com",
Port: 9000,
}
var buf bytes.Buffer
record := testDNSAddr.Record()
require.NoError(t, record.Encode(&buf))
var decodedDNSAddr DNSAddress
decodedRecord := decodedDNSAddr.Record()
require.NoError(t, decodedRecord.Decode(&buf, uint64(buf.Len())))
require.Equal(t, testDNSAddr, decodedDNSAddr)
}
// TestDNSAddressInvalidDecoding tests error cases during TLV decoding.
func TestDNSAddressInvalidDecoding(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
data []byte
errMsg string
}{
{
name: "too short (only 1 byte)",
data: []byte{0x61},
errMsg: "DNS address must be at least 2 bytes",
},
{
name: "empty data",
data: []byte{},
errMsg: "DNS address must be at least 2 bytes",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var dnsAddr DNSAddress
record := dnsAddr.Record()
buf := bytes.NewReader(tc.data)
err := record.Decode(buf, uint64(len(tc.data)))
require.Error(t, err)
require.ErrorContains(t, err, tc.errMsg)
})
}
}
// TestDNSAddressProperty uses property-based testing to verify that DNSAddress
// TLV encoding and decoding is correct for random DNSAddress values.
func TestDNSAddressProperty(t *testing.T) {
t.Parallel()
scenario := func(t *rapid.T) {
// Generate a random valid hostname.
hostname := genValidHostname(t)
// Generate a random port (excluding 0 which is invalid).
port := rapid.Uint16Range(1, 65535).Draw(t, "port")
dnsAddr := DNSAddress{
Hostname: hostname,
Port: port,
}
var buf bytes.Buffer
record := dnsAddr.Record()
err := record.Encode(&buf)
require.NoError(t, err)
var decodedDNSAddr DNSAddress
decodedRecord := decodedDNSAddr.Record()
err = decodedRecord.Decode(&buf, uint64(buf.Len()))
require.NoError(t, err)
require.Equal(t, dnsAddr, decodedDNSAddr)
}
rapid.Check(t, scenario)
}