diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 8400c210e..20edcbab4 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -76,6 +76,9 @@ func (a addressType) AddrLen() uint16 { // WriteElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for the wire protocol. +// +// TODO(yy): rm this method once we finish dereferencing it from other +// packages. func WriteElement(w *bytes.Buffer, element interface{}) error { switch e := element.(type) { case NodeAlias: @@ -433,6 +436,9 @@ func WriteElement(w *bytes.Buffer, element interface{}) error { // WriteElements is writes each element in the elements slice to the passed // buffer using WriteElement. +// +// TODO(yy): rm this method once we finish dereferencing it from other +// packages. func WriteElements(buf *bytes.Buffer, elements ...interface{}) error { for _, element := range elements { err := WriteElement(buf, element) @@ -823,8 +829,11 @@ func ReadElement(r io.Reader, element interface{}) error { length := binary.BigEndian.Uint16(addrLen[:]) var addrBytes [deliveryAddressMaxSize]byte + if length > deliveryAddressMaxSize { - return fmt.Errorf("cannot read %d bytes into addrBytes", length) + return fmt.Errorf( + "cannot read %d bytes into addrBytes", length, + ) } if _, err = io.ReadFull(r, addrBytes[:length]); err != nil { return err diff --git a/lnwire/writer.go b/lnwire/writer.go new file mode 100644 index 000000000..6a20c5715 --- /dev/null +++ b/lnwire/writer.go @@ -0,0 +1,431 @@ +package lnwire + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "image/color" + "math" + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/tor" +) + +var ( + // ErrNilFeatureVector is returned when the supplied feature is nil. + ErrNilFeatureVector = errors.New("cannot write nil feature vector") + + // ErrPkScriptTooLong is returned when the length of the provided + // script exceeds 34. + ErrPkScriptTooLong = errors.New("'PkScript' too long") + + // ErrNilTCPAddress is returned when the supplied address is nil. + ErrNilTCPAddress = errors.New("cannot write nil TCPAddr") + + // ErrNilOnionAddress is returned when the supplied address is nil. + ErrNilOnionAddress = errors.New("cannot write nil onion address") + + // ErrNilNetAddress is returned when a nil value is used in []net.Addr. + ErrNilNetAddress = errors.New("cannot write nil address") + + // ErrNilPublicKey is returned when a nil pubkey is used. + ErrNilPublicKey = errors.New("cannot write nil pubkey") + + // ErrUnknownServiceLength is returned when the onion service length is + // unknown. + ErrUnknownServiceLength = errors.New("unknown onion service length") +) + +// ErrOutpointIndexTooBig is used when the outpoint index exceeds the max value +// of uint16. +func ErrOutpointIndexTooBig(index uint32) error { + return fmt.Errorf( + "index for outpoint (%v) is greater than "+ + "max index of %v", index, math.MaxUint16, + ) +} + +// WriteBytes appends the given bytes to the provided buffer. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func WriteBytes(buf *bytes.Buffer, b []byte) error { // nolint: interfacer + _, err := buf.Write(b) + return err +} + +// WriteUint8 appends the uint8 to the provided buffer. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func WriteUint8(buf *bytes.Buffer, n uint8) error { // nolint: interfacer + _, err := buf.Write([]byte{n}) + return err +} + +// WriteUint16 appends the uint16 to the provided buffer. It encodes the +// integer using big endian byte order. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func WriteUint16(buf *bytes.Buffer, n uint16) error { // nolint: interfacer + var b [2]byte + binary.BigEndian.PutUint16(b[:], n) + _, err := buf.Write(b[:]) + return err +} + +// WriteUint32 appends the uint32 to the provided buffer. It encodes the +// integer using big endian byte order. +func WriteUint32(buf *bytes.Buffer, n uint32) error { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + _, err := buf.Write(b[:]) + return err +} + +// WriteUint64 appends the uint64 to the provided buffer. It encodes the +// integer using big endian byte order. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func WriteUint64(buf *bytes.Buffer, n uint64) error { // nolint: interfacer + var b [8]byte + binary.BigEndian.PutUint64(b[:], n) + _, err := buf.Write(b[:]) + return err +} + +// WriteSatoshi appends the Satoshi value to the provided buffer. +func WriteSatoshi(buf *bytes.Buffer, amount btcutil.Amount) error { + return WriteUint64(buf, uint64(amount)) +} + +// WriteMilliSatoshi appends the MilliSatoshi value to the provided buffer. +func WriteMilliSatoshi(buf *bytes.Buffer, amount MilliSatoshi) error { + return WriteUint64(buf, uint64(amount)) +} + +// WritePublicKey appends the compressed public key to the provided buffer. +func WritePublicKey(buf *bytes.Buffer, pub *btcec.PublicKey) error { + if pub == nil { + return ErrNilPublicKey + } + + serializedPubkey := pub.SerializeCompressed() + return WriteBytes(buf, serializedPubkey) + +} + +// WriteChannelID appends the ChannelID to the provided buffer. +func WriteChannelID(buf *bytes.Buffer, channelID ChannelID) error { + return WriteBytes(buf, channelID[:]) +} + +// WriteNodeAlias appends the alias to the provided buffer. +func WriteNodeAlias(buf *bytes.Buffer, alias NodeAlias) error { + return WriteBytes(buf, alias[:]) +} + +// WriteShortChannelID appends the ShortChannelID to the provided buffer. It +// encodes the BlockHeight and TxIndex each using 3 bytes with big endian byte +// order, and encodes txPosition using 2 bytes with big endian byte order. +func WriteShortChannelID(buf *bytes.Buffer, shortChanID ShortChannelID) error { + // Check that field fit in 3 bytes and write the blockHeight + if shortChanID.BlockHeight > ((1 << 24) - 1) { + return errors.New("block height should fit in 3 bytes") + } + + var blockHeight [4]byte + binary.BigEndian.PutUint32(blockHeight[:], shortChanID.BlockHeight) + + if _, err := buf.Write(blockHeight[1:]); err != nil { + return err + } + + // Check that field fit in 3 bytes and write the txIndex + if shortChanID.TxIndex > ((1 << 24) - 1) { + return errors.New("tx index should fit in 3 bytes") + } + + var txIndex [4]byte + binary.BigEndian.PutUint32(txIndex[:], shortChanID.TxIndex) + if _, err := buf.Write(txIndex[1:]); err != nil { + return err + } + + // Write the TxPosition + return WriteUint16(buf, shortChanID.TxPosition) +} + +// WriteSig appends the signature to the provided buffer. +func WriteSig(buf *bytes.Buffer, sig Sig) error { + return WriteBytes(buf, sig[:]) +} + +// WriteSigs appends the slice of signatures to the provided buffer with its +// length. +func WriteSigs(buf *bytes.Buffer, sigs []Sig) error { + // Write the length of the sigs. + if err := WriteUint16(buf, uint16(len(sigs))); err != nil { + return err + } + + for _, sig := range sigs { + if err := WriteSig(buf, sig); err != nil { + return err + } + } + return nil +} + +// WriteFailCode appends the FailCode to the provided buffer. +func WriteFailCode(buf *bytes.Buffer, e FailCode) error { + return WriteUint16(buf, uint16(e)) +} + +// WriteRawFeatureVector encodes the feature using the feature's Encode method +// and appends the data to the provided buffer. An error will return if the +// passed feature is nil. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func WriteRawFeatureVector(buf *bytes.Buffer, // nolint: interfacer + feature *RawFeatureVector) error { + + if feature == nil { + return ErrNilFeatureVector + } + + return feature.Encode(buf) +} + +// WriteColorRGBA appends the RGBA color using three bytes. +func WriteColorRGBA(buf *bytes.Buffer, e color.RGBA) error { + // Write R + if err := WriteUint8(buf, e.R); err != nil { + return err + } + + // Write G + if err := WriteUint8(buf, e.G); err != nil { + return err + } + + // Write B + return WriteUint8(buf, e.B) +} + +// WriteShortChanIDEncoding appends the ShortChanIDEncoding to the provided +// buffer. +func WriteShortChanIDEncoding(buf *bytes.Buffer, e ShortChanIDEncoding) error { + return WriteUint8(buf, uint8(e)) +} + +// WriteFundingFlag appends the FundingFlag to the provided buffer. +func WriteFundingFlag(buf *bytes.Buffer, flag FundingFlag) error { + return WriteUint8(buf, uint8(flag)) +} + +// WriteChanUpdateMsgFlags appends the update flag to the provided buffer. +func WriteChanUpdateMsgFlags(buf *bytes.Buffer, f ChanUpdateMsgFlags) error { + return WriteUint8(buf, uint8(f)) +} + +// WriteChanUpdateChanFlags appends the update flag to the provided buffer. +func WriteChanUpdateChanFlags(buf *bytes.Buffer, f ChanUpdateChanFlags) error { + return WriteUint8(buf, uint8(f)) +} + +// WriteDeliveryAddress appends the address to the provided buffer. +func WriteDeliveryAddress(buf *bytes.Buffer, addr DeliveryAddress) error { + return writeDataWithLength(buf, addr) +} + +// WritePingPayload appends the payload to the provided buffer. +func WritePingPayload(buf *bytes.Buffer, payload PingPayload) error { + return writeDataWithLength(buf, payload) +} + +// WritePongPayload appends the payload to the provided buffer. +func WritePongPayload(buf *bytes.Buffer, payload PongPayload) error { + return writeDataWithLength(buf, payload) +} + +// WriteErrorData appends the data to the provided buffer. +func WriteErrorData(buf *bytes.Buffer, data ErrorData) error { + return writeDataWithLength(buf, data) +} + +// WriteOpaqueReason appends the reason to the provided buffer. +func WriteOpaqueReason(buf *bytes.Buffer, reason OpaqueReason) error { + return writeDataWithLength(buf, reason) +} + +// WriteBool appends the boolean to the provided buffer. +func WriteBool(buf *bytes.Buffer, b bool) error { + if b { + return WriteBytes(buf, []byte{1}) + } + return WriteBytes(buf, []byte{0}) +} + +// WritePkScript appends the script to the provided buffer. Returns an error if +// the provided script exceeds 34 bytes. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func WritePkScript(buf *bytes.Buffer, s PkScript) error { // nolint: interfacer + // The largest script we'll accept is a p2wsh which is exactly + // 34 bytes long. + scriptLength := len(s) + if scriptLength > 34 { + return ErrPkScriptTooLong + } + + return wire.WriteVarBytes(buf, 0, s) +} + +// WriteOutPoint appends the outpoint to the provided buffer. +func WriteOutPoint(buf *bytes.Buffer, p wire.OutPoint) error { + // Before we write anything to the buffer, check the Index is sane. + if p.Index > math.MaxUint16 { + return ErrOutpointIndexTooBig(p.Index) + } + + var h [32]byte + copy(h[:], p.Hash[:]) + if _, err := buf.Write(h[:]); err != nil { + return err + } + + // Write the index using two bytes. + return WriteUint16(buf, uint16(p.Index)) +} + +// WriteTCPAddr appends the TCP address to the provided buffer, either a IPv4 +// or a IPv6. +func WriteTCPAddr(buf *bytes.Buffer, addr *net.TCPAddr) error { + if addr == nil { + return ErrNilTCPAddress + } + + // Make a slice of bytes to hold the data of descriptor and ip. At + // most, we need 17 bytes - 1 byte for the descriptor, 16 bytes for + // IPv6. + data := make([]byte, 0, 17) + + if addr.IP.To4() != nil { + data = append(data, uint8(tcp4Addr)) + data = append(data, addr.IP.To4()...) + } else { + data = append(data, uint8(tcp6Addr)) + data = append(data, addr.IP.To16()...) + } + + if _, err := buf.Write(data); err != nil { + return err + } + + return WriteUint16(buf, uint16(addr.Port)) +} + +// WriteOnionAddr appends the onion address to the provided buffer. +func WriteOnionAddr(buf *bytes.Buffer, addr *tor.OnionAddr) error { + if addr == nil { + return ErrNilOnionAddress + } + + var ( + suffixIndex int + descriptor []byte + ) + + // Decide the suffixIndex and descriptor. + switch len(addr.OnionService) { + case tor.V2Len: + descriptor = []byte{byte(v2OnionAddr)} + suffixIndex = tor.V2Len - tor.OnionSuffixLen + + case tor.V3Len: + descriptor = []byte{byte(v3OnionAddr)} + suffixIndex = tor.V3Len - tor.OnionSuffixLen + + default: + return ErrUnknownServiceLength + } + + // Decode the address. + host, err := tor.Base32Encoding.DecodeString( + addr.OnionService[:suffixIndex], + ) + if err != nil { + return err + } + + // Perform the actual write when the above checks passed. + if _, err := buf.Write(descriptor); err != nil { + return err + } + if _, err := buf.Write(host); err != nil { + return err + } + + return WriteUint16(buf, uint16(addr.Port)) +} + +// WriteNetAddrs appends a slice of addresses to the provided buffer with the +// length info. +func WriteNetAddrs(buf *bytes.Buffer, addresses []net.Addr) error { + // First, we'll encode all the addresses into an intermediate + // buffer. We need to do this in order to compute the total + // length of the addresses. + buffer := make([]byte, 0, MaxMsgBody) + addrBuf := bytes.NewBuffer(buffer) + + for _, address := range addresses { + switch a := address.(type) { + case *net.TCPAddr: + if err := WriteTCPAddr(addrBuf, a); err != nil { + return err + } + case *tor.OnionAddr: + if err := WriteOnionAddr(addrBuf, a); err != nil { + return err + } + default: + return ErrNilNetAddress + } + } + + // With the addresses fully encoded, we can now write out data. + return writeDataWithLength(buf, addrBuf.Bytes()) +} + +// writeDataWithLength writes the data and its length to the buffer. +// +// Note: We intentionally skip the interfacer linter check here because we want +// to have concrete type (bytes.Buffer) rather than interface type (io.Write) +// due to performance concern. +func writeDataWithLength(buf *bytes.Buffer, // nolint: interfacer + data []byte) error { + + var l [2]byte + binary.BigEndian.PutUint16(l[:], uint16(len(data))) + if _, err := buf.Write(l[:]); err != nil { + return err + } + + _, err := buf.Write(data) + return err +} diff --git a/lnwire/writer_test.go b/lnwire/writer_test.go new file mode 100644 index 000000000..96ef66103 --- /dev/null +++ b/lnwire/writer_test.go @@ -0,0 +1,618 @@ +package lnwire + +import ( + "bytes" + "encoding/base32" + "image/color" + "math" + "net" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/tor" + "github.com/stretchr/testify/require" +) + +func TestWriteBytes(t *testing.T) { + buf := new(bytes.Buffer) + data := []byte{1, 2, 3} + + err := WriteBytes(buf, data) + + require.NoError(t, err) + require.Equal(t, data, buf.Bytes()) +} + +func TestWriteUint8(t *testing.T) { + buf := new(bytes.Buffer) + data := uint8(1) + expectedBytes := []byte{1} + + err := WriteUint8(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteUint16(t *testing.T) { + buf := new(bytes.Buffer) + data := uint16(1) + expectedBytes := []byte{0, 1} + + err := WriteUint16(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteUint32(t *testing.T) { + buf := new(bytes.Buffer) + data := uint32(1) + expectedBytes := []byte{0, 0, 0, 1} + + err := WriteUint32(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteUint64(t *testing.T) { + buf := new(bytes.Buffer) + data := uint64(1) + expectedBytes := []byte{0, 0, 0, 0, 0, 0, 0, 1} + + err := WriteUint64(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteSatoshi(t *testing.T) { + buf := new(bytes.Buffer) + data := btcutil.Amount(1) + expectedBytes := []byte{0, 0, 0, 0, 0, 0, 0, 1} + + err := WriteSatoshi(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteMilliSatoshi(t *testing.T) { + buf := new(bytes.Buffer) + data := MilliSatoshi(1) + expectedBytes := []byte{0, 0, 0, 0, 0, 0, 0, 1} + + err := WriteMilliSatoshi(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWritePublicKey(t *testing.T) { + buf := new(bytes.Buffer) + + // Check that when nil pubkey is used, an error will return. + err := WritePublicKey(buf, nil) + require.Equal(t, ErrNilPublicKey, err) + + pub, err := randPubKey() + require.NoError(t, err) + expectedBytes := pub.SerializeCompressed() + + err = WritePublicKey(buf, pub) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteChannelID(t *testing.T) { + buf := new(bytes.Buffer) + data := ChannelID{1} + expectedBytes := [32]byte{1} + + err := WriteChannelID(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes[:], buf.Bytes()) +} + +func TestWriteNodeAlias(t *testing.T) { + buf := new(bytes.Buffer) + data := NodeAlias{1} + expectedBytes := [32]byte{1} + + err := WriteNodeAlias(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes[:], buf.Bytes()) +} + +func TestWriteShortChannelID(t *testing.T) { + buf := new(bytes.Buffer) + data := ShortChannelID{BlockHeight: 1, TxIndex: 2, TxPosition: 3} + expectedBytes := []byte{ + 0, 0, 1, // First three bytes encodes BlockHeight. + 0, 0, 2, // Second three bytes encodes TxIndex. + 0, 3, // Final two bytes encodes TxPosition. + } + + err := WriteShortChannelID(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteSig(t *testing.T) { + buf := new(bytes.Buffer) + data := Sig{1, 2, 3} + expectedBytes := [64]byte{1, 2, 3} + + err := WriteSig(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes[:], buf.Bytes()) +} + +func TestWriteSigs(t *testing.T) { + buf := new(bytes.Buffer) + sig1, sig2, sig3 := Sig{1}, Sig{2}, Sig{3} + data := []Sig{sig1, sig2, sig3} + + // First two bytes encode the length of the slice. + expectedBytes := []byte{0, 3} + expectedBytes = append(expectedBytes, sig1[:]...) + expectedBytes = append(expectedBytes, sig2[:]...) + expectedBytes = append(expectedBytes, sig3[:]...) + + err := WriteSigs(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteFailCode(t *testing.T) { + buf := new(bytes.Buffer) + data := FailCode(1) + expectedBytes := []byte{0, 1} + + err := WriteFailCode(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +// TODO(yy): expand the test to cover more encoding scenarios. +func TestWriteRawFeatureVector(t *testing.T) { + buf := new(bytes.Buffer) + + // Check that when nil feature is used, an error will return. + err := WriteRawFeatureVector(buf, nil) + require.Equal(t, ErrNilFeatureVector, err) + + // Create a raw feature vector. + feature := &RawFeatureVector{features: map[FeatureBit]bool{ + InitialRoutingSync: true, // FeatureBit 3. + }} + expectedBytes := []byte{ + 0, 1, // First two bytes encode the length. + 8, // Last byte encodes the feature bit (1 << 3). + } + + err = WriteRawFeatureVector(buf, feature) + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteColorRGBA(t *testing.T) { + buf := new(bytes.Buffer) + data := color.RGBA{R: 1, G: 2, B: 3} + expectedBytes := []byte{1, 2, 3} + + err := WriteColorRGBA(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteShortChanIDEncoding(t *testing.T) { + buf := new(bytes.Buffer) + data := ShortChanIDEncoding(1) + expectedBytes := []byte{1} + + err := WriteShortChanIDEncoding(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteFundingFlag(t *testing.T) { + buf := new(bytes.Buffer) + data := FundingFlag(1) + expectedBytes := []byte{1} + + err := WriteFundingFlag(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteChanUpdateMsgFlags(t *testing.T) { + buf := new(bytes.Buffer) + data := ChanUpdateMsgFlags(1) + expectedBytes := []byte{1} + + err := WriteChanUpdateMsgFlags(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteChanUpdateChanFlags(t *testing.T) { + buf := new(bytes.Buffer) + data := ChanUpdateChanFlags(1) + expectedBytes := []byte{1} + + err := WriteChanUpdateChanFlags(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteDeliveryAddress(t *testing.T) { + buf := new(bytes.Buffer) + data := DeliveryAddress{1, 1, 1} + expectedBytes := []byte{ + 0, 3, // First two bytes encode the length. + 1, 1, 1, // The actual data. + } + + err := WriteDeliveryAddress(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWritePingPayload(t *testing.T) { + buf := new(bytes.Buffer) + data := PingPayload{1, 1, 1} + expectedBytes := []byte{ + 0, 3, // First two bytes encode the length. + 1, 1, 1, // The actual data. + } + + err := WritePingPayload(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWritePongPayload(t *testing.T) { + buf := new(bytes.Buffer) + data := PongPayload{1, 1, 1} + expectedBytes := []byte{ + 0, 3, // First two bytes encode the length. + 1, 1, 1, // The actual data. + } + + err := WritePongPayload(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteErrorData(t *testing.T) { + buf := new(bytes.Buffer) + data := ErrorData{1, 1, 1} + expectedBytes := []byte{ + 0, 3, // First two bytes encode the length. + 1, 1, 1, // The actual data. + } + + err := WriteErrorData(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteOpaqueReason(t *testing.T) { + buf := new(bytes.Buffer) + data := OpaqueReason{1, 1, 1} + expectedBytes := []byte{ + 0, 3, // First two bytes encode the length. + 1, 1, 1, // The actual data. + } + + err := WriteOpaqueReason(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteBool(t *testing.T) { + buf := new(bytes.Buffer) + + // Test write true. + data := true + expectedBytes := []byte{1} + + err := WriteBool(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) + + // Test write false. + data = false + expectedBytes = append(expectedBytes, 0) + + err = WriteBool(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWritePkScript(t *testing.T) { + buf := new(bytes.Buffer) + + // Write a very long script to check the error is returned as expected. + script := PkScript{} + zeros := [35]byte{} + script = append(script, zeros[:]...) + err := WritePkScript(buf, script) + require.Equal(t, ErrPkScriptTooLong, err) + + data := PkScript{1, 1, 1} + expectedBytes := []byte{ + 3, // First byte encodes the length. + 1, 1, 1, // The actual data. + } + + err = WritePkScript(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteOutPoint(t *testing.T) { + buf := new(bytes.Buffer) + + // Create an outpoint with very large index to check the error is + // returned as expected. + outpointWrong := wire.OutPoint{Index: math.MaxUint16 + 1} + err := WriteOutPoint(buf, outpointWrong) + require.Equal(t, ErrOutpointIndexTooBig(outpointWrong.Index), err) + + // Now check the normal write succeeds. + hash := chainhash.Hash{1} + data := wire.OutPoint{Index: 2, Hash: hash} + expectedBytes := []byte{} + expectedBytes = append(expectedBytes, hash[:]...) + expectedBytes = append(expectedBytes, []byte{0, 2}...) + + err = WriteOutPoint(buf, data) + + require.NoError(t, err) + require.Equal(t, expectedBytes, buf.Bytes()) +} + +func TestWriteTCPAddr(t *testing.T) { + buf := new(bytes.Buffer) + + testCases := []struct { + name string + addr *net.TCPAddr + + expectedErr error + expectedBytes []byte + }{ + { + // Check that the error is returned when nil address is + // used. + name: "nil address err", + addr: nil, + expectedErr: ErrNilTCPAddress, + expectedBytes: nil, + }, + { + // Check write IPv4. + name: "write ipv4", + addr: &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 8080, + }, + expectedErr: nil, + expectedBytes: []byte{ + 0x1, // The addressType. + 0x7f, 0x0, 0x0, 0x1, // The IP. + 0x1f, 0x90, // The port (31 * 256 + 144). + }, + }, + { + // Check write IPv6. + name: "write ipv6", + addr: &net.TCPAddr{ + IP: net.IP{ + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + }, + Port: 8080, + }, + expectedErr: nil, + expectedBytes: []byte{ + 0x2, // The addressType. + // The IP. + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1f, 0x90, // The port (31 * 256 + 144). + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + oldLen := buf.Len() + + err := WriteTCPAddr(buf, tc.addr) + require.Equal(t, tc.expectedErr, err) + + bytesWritten := buf.Bytes()[oldLen:buf.Len()] + require.Equal(t, tc.expectedBytes, bytesWritten) + }) + } +} + +func TestWriteOnionAddr(t *testing.T) { + buf := new(bytes.Buffer) + + testCases := []struct { + name string + addr *tor.OnionAddr + + expectedErr error + expectedBytes []byte + }{ + { + // Check that the error is returned when nil address is + // used. + name: "nil address err", + addr: nil, + expectedErr: ErrNilOnionAddress, + expectedBytes: nil, + }, + { + // Check the error is returned when an invalid onion + // address is used. + name: "wrong onion service length", + addr: &tor.OnionAddr{OnionService: "wrong"}, + expectedErr: ErrUnknownServiceLength, + expectedBytes: nil, + }, + { + // Check when the address has invalid base32 encoding, + // the error is returned. + name: "invalid base32 encoding", + addr: &tor.OnionAddr{ + OnionService: "1234567890123456.onion", + }, + expectedErr: base32.CorruptInputError(0), + expectedBytes: nil, + }, + { + // Check write onion v2. + name: "onion address v2", + addr: &tor.OnionAddr{ + OnionService: "abcdefghijklmnop.onion", + Port: 9065, + }, + expectedErr: nil, + expectedBytes: []byte{ + 0x3, // The descriptor. + 0x0, 0x44, 0x32, 0x14, 0xc7, // The host. + 0x42, 0x54, 0xb6, 0x35, 0xcf, + 0x23, 0x69, // The port. + }, + }, + { + // Check write onion v3. + name: "onion address v3", + addr: &tor.OnionAddr{ + OnionService: "abcdefghij" + + "abcdefghijabcdefghij" + + "abcdefghijabcdefghij" + + "234567.onion", + Port: 9065, + }, + expectedErr: nil, + expectedBytes: []byte{ + 0x4, // The descriptor. + 0x0, 0x44, 0x32, 0x14, 0xc7, 0x42, 0x40, + 0x11, 0xc, 0x85, 0x31, 0xd0, 0x90, 0x4, + 0x43, 0x21, 0x4c, 0x74, 0x24, 0x1, 0x10, + 0xc8, 0x53, 0x1d, 0x9, 0x0, 0x44, 0x32, + 0x14, 0xc7, 0x42, 0x75, 0xbe, 0x77, 0xdf, + 0x23, 0x69, // The port. + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + oldLen := buf.Len() + + err := WriteOnionAddr(buf, tc.addr) + require.Equal(t, tc.expectedErr, err) + + bytesWritten := buf.Bytes()[oldLen:buf.Len()] + require.Equal(t, tc.expectedBytes, bytesWritten) + }) + } +} + +func TestWriteNetAddrs(t *testing.T) { + buf := new(bytes.Buffer) + tcpAddr := &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 8080, + } + onionAddr := &tor.OnionAddr{ + OnionService: "abcdefghijklmnop.onion", + Port: 9065, + } + + testCases := []struct { + name string + addr []net.Addr + + expectedErr error + expectedBytes []byte + }{ + { + // Check that the error is returned when nil address is + // used. + name: "nil address err", + addr: []net.Addr{nil, tcpAddr, onionAddr}, + expectedErr: ErrNilNetAddress, + expectedBytes: nil, + }, + { + // Check empty address slice. + name: "empty address slice", + addr: []net.Addr{}, + expectedErr: nil, + // Use two bytes to encode the address size. + expectedBytes: []byte{0, 0}, + }, + { + // Check a successful writes of a slice of addresses. + name: "two addresses", + addr: []net.Addr{tcpAddr, onionAddr}, + expectedErr: nil, + expectedBytes: []byte{ + // 7 bytes for TCP and 13 bytes for onion. + 0x0, 0x14, + // TCP address. + 0x1, 0x7f, 0x0, 0x0, 0x1, 0x1f, 0x90, + // Onion address. + 0x3, 0x0, 0x44, 0x32, 0x14, 0xc7, 0x42, + 0x54, 0xb6, 0x35, 0xcf, 0x23, 0x69, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + oldLen := buf.Len() + + err := WriteNetAddrs(buf, tc.addr) + require.Equal(t, tc.expectedErr, err) + + bytesWritten := buf.Bytes()[oldLen:buf.Len()] + require.Equal(t, tc.expectedBytes, bytesWritten) + }) + } +}