mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-10 14:17:56 +01:00
lnwire: let DNSAddress implement RecordProducer
In preparation for using this type as a TLV record, we let it implement the RecordProducer interface.
This commit is contained in:
@@ -1,10 +1,14 @@
|
|||||||
package lnwire
|
package lnwire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -86,3 +90,70 @@ func ValidateDNSAddr(hostname string, port uint16) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record returns a TLV record that can be used to encode/decode the DNSAddress.
|
||||||
|
//
|
||||||
|
// NOTE: this is part of the tlv.RecordProducer interface.
|
||||||
|
func (d *DNSAddress) Record() tlv.Record {
|
||||||
|
sizeFunc := func() uint64 {
|
||||||
|
// Hostname length + 2 bytes for port.
|
||||||
|
return uint64(len(d.Hostname) + 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlv.MakeDynamicRecord(
|
||||||
|
0, d, sizeFunc, dnsAddressEncoder, dnsAddressDecoder,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsAddressEncoder is a TLV encoder for DNSAddress.
|
||||||
|
func dnsAddressEncoder(w io.Writer, val any, _ *[8]byte) error {
|
||||||
|
if v, ok := val.(*DNSAddress); ok {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
// Write the hostname as raw bytes (no length prefix for TLV).
|
||||||
|
if _, err := buf.WriteString(v.Hostname); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the port as 2 bytes.
|
||||||
|
err := WriteUint16(&buf, v.Port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = w.Write(buf.Bytes())
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlv.NewTypeForEncodingErr(val, "DNSAddress")
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsAddressDecoder is a TLV decoder for DNSAddress.
|
||||||
|
func dnsAddressDecoder(r io.Reader, val any, _ *[8]byte,
|
||||||
|
l uint64) error {
|
||||||
|
|
||||||
|
if v, ok := val.(*DNSAddress); ok {
|
||||||
|
if l < 2 {
|
||||||
|
return fmt.Errorf("DNS address must be at least 2 " +
|
||||||
|
"bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read hostname (all bytes except last 2).
|
||||||
|
hostnameLen := l - 2
|
||||||
|
hostnameBytes := make([]byte, hostnameLen)
|
||||||
|
if _, err := io.ReadFull(r, hostnameBytes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
v.Hostname = string(hostnameBytes)
|
||||||
|
|
||||||
|
// Read port (last 2 bytes).
|
||||||
|
if err := ReadElement(r, &v.Port); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateDNSAddr(v.Hostname, v.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlv.NewTypeForDecodingErr(val, "DNSAddress", l, 0)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
package lnwire
|
package lnwire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"pgregory.net/rapid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestValidateDNSAddr tests hostname and port validation per BOLT #7.
|
// TestValidateDNSAddr tests hostname and port validation per BOLT #7.
|
||||||
@@ -85,3 +88,112 @@ func TestValidateDNSAddr(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDNSAddressTLVEncoding tests the TLV encoding and decoding of DNSAddress
|
||||||
|
// structs using the ExtraOpaqueData interface.
|
||||||
|
func TestDNSAddressTLVEncoding(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testDNSAddr := DNSAddress{
|
||||||
|
Hostname: "lightning.example.com",
|
||||||
|
Port: 9000,
|
||||||
|
}
|
||||||
|
|
||||||
|
var extraData ExtraOpaqueData
|
||||||
|
require.NoError(t, extraData.PackRecords(&testDNSAddr))
|
||||||
|
|
||||||
|
var decodedDNSAddr DNSAddress
|
||||||
|
tlvs, err := extraData.ExtractRecords(&decodedDNSAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Contains(t, tlvs, tlv.Type(0))
|
||||||
|
require.Equal(t, testDNSAddr, decodedDNSAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSAddressRecord tests the TLV Record interface of DNSAddress
|
||||||
|
// by directly encoding and decoding using the Record method.
|
||||||
|
func TestDNSAddressRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testDNSAddr := DNSAddress{
|
||||||
|
Hostname: "lightning.example.com",
|
||||||
|
Port: 9000,
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
record := testDNSAddr.Record()
|
||||||
|
require.NoError(t, record.Encode(&buf))
|
||||||
|
|
||||||
|
var decodedDNSAddr DNSAddress
|
||||||
|
decodedRecord := decodedDNSAddr.Record()
|
||||||
|
require.NoError(t, decodedRecord.Decode(&buf, uint64(buf.Len())))
|
||||||
|
|
||||||
|
require.Equal(t, testDNSAddr, decodedDNSAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSAddressInvalidDecoding tests error cases during TLV decoding.
|
||||||
|
func TestDNSAddressInvalidDecoding(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "too short (only 1 byte)",
|
||||||
|
data: []byte{0x61},
|
||||||
|
errMsg: "DNS address must be at least 2 bytes",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty data",
|
||||||
|
data: []byte{},
|
||||||
|
errMsg: "DNS address must be at least 2 bytes",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var dnsAddr DNSAddress
|
||||||
|
record := dnsAddr.Record()
|
||||||
|
|
||||||
|
buf := bytes.NewReader(tc.data)
|
||||||
|
err := record.Decode(buf, uint64(len(tc.data)))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, tc.errMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSAddressProperty uses property-based testing to verify that DNSAddress
|
||||||
|
// TLV encoding and decoding is correct for random DNSAddress values.
|
||||||
|
func TestDNSAddressProperty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scenario := func(t *rapid.T) {
|
||||||
|
// Generate a random valid hostname.
|
||||||
|
hostname := genValidHostname(t)
|
||||||
|
|
||||||
|
// Generate a random port (excluding 0 which is invalid).
|
||||||
|
port := rapid.Uint16Range(1, 65535).Draw(t, "port")
|
||||||
|
|
||||||
|
dnsAddr := DNSAddress{
|
||||||
|
Hostname: hostname,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
record := dnsAddr.Record()
|
||||||
|
err := record.Encode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var decodedDNSAddr DNSAddress
|
||||||
|
decodedRecord := decodedDNSAddr.Record()
|
||||||
|
err = decodedRecord.Decode(&buf, uint64(buf.Len()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, dnsAddr, decodedDNSAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
rapid.Check(t, scenario)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1867,3 +1867,24 @@ func (c *Error) RandTestMessage(t *rapid.T) Message {
|
|||||||
|
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// genValidHostname generates a random valid hostname according to BOLT #7
|
||||||
|
// rules.
|
||||||
|
func genValidHostname(t *rapid.T) string {
|
||||||
|
// Valid characters: a-z, A-Z, 0-9, -, .
|
||||||
|
validChars := "abcdefghijklmnopqrstuvwxyzABCDE" +
|
||||||
|
"FGHIJKLMNOPQRSTUVWXYZ0123456789-."
|
||||||
|
|
||||||
|
// Generate hostname length between 1 and 255 characters.
|
||||||
|
length := rapid.IntRange(1, 255).Draw(t, "hostname_length")
|
||||||
|
|
||||||
|
hostname := make([]byte, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
charIndex := rapid.IntRange(0, len(validChars)-1).Draw(
|
||||||
|
t, fmt.Sprintf("char_%d", i),
|
||||||
|
)
|
||||||
|
hostname[i] = validChars[charIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(hostname)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user