From e24dd2f9e09455db0c6e43e31e4bf99fb0265ca7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 22 Sep 2025 12:09:46 +0200 Subject: [PATCH] lnwire: let DNSAddress implement RecordProducer In preparation for using this type as a TLV record, we let it implement the RecordProducer interface. --- lnwire/dns_addr.go | 71 +++++++++++++++++++++++++ lnwire/dns_addr_test.go | 112 ++++++++++++++++++++++++++++++++++++++++ lnwire/test_message.go | 21 ++++++++ 3 files changed, 204 insertions(+) diff --git a/lnwire/dns_addr.go b/lnwire/dns_addr.go index 87ddcd8cd..ed0c6fade 100644 --- a/lnwire/dns_addr.go +++ b/lnwire/dns_addr.go @@ -1,10 +1,14 @@ package lnwire import ( + "bytes" "errors" "fmt" + "io" "net" "strconv" + + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -86,3 +90,70 @@ func ValidateDNSAddr(hostname string, port uint16) error { return nil } + +// Record returns a TLV record that can be used to encode/decode the DNSAddress. +// +// NOTE: this is part of the tlv.RecordProducer interface. +func (d *DNSAddress) Record() tlv.Record { + sizeFunc := func() uint64 { + // Hostname length + 2 bytes for port. + return uint64(len(d.Hostname) + 2) + } + + return tlv.MakeDynamicRecord( + 0, d, sizeFunc, dnsAddressEncoder, dnsAddressDecoder, + ) +} + +// dnsAddressEncoder is a TLV encoder for DNSAddress. +func dnsAddressEncoder(w io.Writer, val any, _ *[8]byte) error { + if v, ok := val.(*DNSAddress); ok { + var buf bytes.Buffer + + // Write the hostname as raw bytes (no length prefix for TLV). + if _, err := buf.WriteString(v.Hostname); err != nil { + return err + } + + // Write the port as 2 bytes. + err := WriteUint16(&buf, v.Port) + if err != nil { + return err + } + + _, err = w.Write(buf.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "DNSAddress") +} + +// dnsAddressDecoder is a TLV decoder for DNSAddress. +func dnsAddressDecoder(r io.Reader, val any, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*DNSAddress); ok { + if l < 2 { + return fmt.Errorf("DNS address must be at least 2 " + + "bytes") + } + + // Read hostname (all bytes except last 2). + hostnameLen := l - 2 + hostnameBytes := make([]byte, hostnameLen) + if _, err := io.ReadFull(r, hostnameBytes); err != nil { + return err + } + v.Hostname = string(hostnameBytes) + + // Read port (last 2 bytes). + if err := ReadElement(r, &v.Port); err != nil { + return err + } + + return ValidateDNSAddr(v.Hostname, v.Port) + } + + return tlv.NewTypeForDecodingErr(val, "DNSAddress", l, 0) +} diff --git a/lnwire/dns_addr_test.go b/lnwire/dns_addr_test.go index 8cdf2a858..1808fa8dd 100644 --- a/lnwire/dns_addr_test.go +++ b/lnwire/dns_addr_test.go @@ -1,11 +1,14 @@ package lnwire import ( + "bytes" "fmt" "strings" "testing" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" + "pgregory.net/rapid" ) // TestValidateDNSAddr tests hostname and port validation per BOLT #7. @@ -85,3 +88,112 @@ func TestValidateDNSAddr(t *testing.T) { }) } } + +// TestDNSAddressTLVEncoding tests the TLV encoding and decoding of DNSAddress +// structs using the ExtraOpaqueData interface. +func TestDNSAddressTLVEncoding(t *testing.T) { + t.Parallel() + + testDNSAddr := DNSAddress{ + Hostname: "lightning.example.com", + Port: 9000, + } + + var extraData ExtraOpaqueData + require.NoError(t, extraData.PackRecords(&testDNSAddr)) + + var decodedDNSAddr DNSAddress + tlvs, err := extraData.ExtractRecords(&decodedDNSAddr) + require.NoError(t, err) + + require.Contains(t, tlvs, tlv.Type(0)) + require.Equal(t, testDNSAddr, decodedDNSAddr) +} + +// TestDNSAddressRecord tests the TLV Record interface of DNSAddress +// by directly encoding and decoding using the Record method. +func TestDNSAddressRecord(t *testing.T) { + t.Parallel() + + testDNSAddr := DNSAddress{ + Hostname: "lightning.example.com", + Port: 9000, + } + + var buf bytes.Buffer + record := testDNSAddr.Record() + require.NoError(t, record.Encode(&buf)) + + var decodedDNSAddr DNSAddress + decodedRecord := decodedDNSAddr.Record() + require.NoError(t, decodedRecord.Decode(&buf, uint64(buf.Len()))) + + require.Equal(t, testDNSAddr, decodedDNSAddr) +} + +// TestDNSAddressInvalidDecoding tests error cases during TLV decoding. +func TestDNSAddressInvalidDecoding(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + data []byte + errMsg string + }{ + { + name: "too short (only 1 byte)", + data: []byte{0x61}, + errMsg: "DNS address must be at least 2 bytes", + }, + { + name: "empty data", + data: []byte{}, + errMsg: "DNS address must be at least 2 bytes", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var dnsAddr DNSAddress + record := dnsAddr.Record() + + buf := bytes.NewReader(tc.data) + err := record.Decode(buf, uint64(len(tc.data))) + require.Error(t, err) + require.ErrorContains(t, err, tc.errMsg) + }) + } +} + +// TestDNSAddressProperty uses property-based testing to verify that DNSAddress +// TLV encoding and decoding is correct for random DNSAddress values. +func TestDNSAddressProperty(t *testing.T) { + t.Parallel() + + scenario := func(t *rapid.T) { + // Generate a random valid hostname. + hostname := genValidHostname(t) + + // Generate a random port (excluding 0 which is invalid). + port := rapid.Uint16Range(1, 65535).Draw(t, "port") + + dnsAddr := DNSAddress{ + Hostname: hostname, + Port: port, + } + + var buf bytes.Buffer + record := dnsAddr.Record() + err := record.Encode(&buf) + require.NoError(t, err) + + var decodedDNSAddr DNSAddress + decodedRecord := decodedDNSAddr.Record() + err = decodedRecord.Decode(&buf, uint64(buf.Len())) + require.NoError(t, err) + + require.Equal(t, dnsAddr, decodedDNSAddr) + } + + rapid.Check(t, scenario) +} diff --git a/lnwire/test_message.go b/lnwire/test_message.go index fd0d1d66a..ff7db2db6 100644 --- a/lnwire/test_message.go +++ b/lnwire/test_message.go @@ -1867,3 +1867,24 @@ func (c *Error) RandTestMessage(t *rapid.T) Message { return msg } + +// genValidHostname generates a random valid hostname according to BOLT #7 +// rules. +func genValidHostname(t *rapid.T) string { + // Valid characters: a-z, A-Z, 0-9, -, . + validChars := "abcdefghijklmnopqrstuvwxyzABCDE" + + "FGHIJKLMNOPQRSTUVWXYZ0123456789-." + + // Generate hostname length between 1 and 255 characters. + length := rapid.IntRange(1, 255).Draw(t, "hostname_length") + + hostname := make([]byte, length) + for i := 0; i < length; i++ { + charIndex := rapid.IntRange(0, len(validChars)-1).Draw( + t, fmt.Sprintf("char_%d", i), + ) + hostname[i] = validChars[charIndex] + } + + return string(hostname) +}