diff --git a/channeldb/migration30/revocation_log.go b/channeldb/migration30/revocation_log.go index 51d92b72e..24382e707 100644 --- a/channeldb/migration30/revocation_log.go +++ b/channeldb/migration30/revocation_log.go @@ -16,8 +16,18 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -// OutputIndexEmpty is used when the output index doesn't exist. -const OutputIndexEmpty = math.MaxUint16 +const ( + // OutputIndexEmpty is used when the output index doesn't exist. + OutputIndexEmpty = math.MaxUint16 + + // A set of tlv type definitions used to serialize the body of + // revocation logs to the database. + // + // NOTE: A migration should be added whenever this list changes. + revLogOurOutputIndexType tlv.Type = 0 + revLogTheirOutputIndexType tlv.Type = 1 + revLogCommitTxHashType tlv.Type = 2 +) var ( // revocationLogBucketDeprecated is dedicated for storing the necessary @@ -208,29 +218,6 @@ type RevocationLog struct { HTLCEntries []*HTLCEntry } -// toTlvStream converts an RevocationLog record into a tlv representation. -func (rl *RevocationLog) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize the body of - // revocation logs to the database. We define it here instead - // of the head of the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - ourOutputIndexType tlv.Type = 0 - theirOutputIndexType tlv.Type = 1 - commitTxHashType tlv.Type = 2 - ) - - return tlv.NewStream( - tlv.MakePrimitiveRecord(ourOutputIndexType, &rl.OurOutputIndex), - tlv.MakePrimitiveRecord( - theirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord(commitTxHashType, &rl.CommitTxHash), - ) -} - // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a // ChannelCommitment to construct a revocation log entry and saves them to // disk. It also saves our output index and their output index, which are @@ -304,8 +291,21 @@ func fetchRevocationLog(log kvdb.RBucket, // serializeRevocationLog serializes a RevocationLog record based on tlv // format. func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { + // Add the tlv records for all non-optional fields. + records := []tlv.Record{ + tlv.MakePrimitiveRecord( + revLogOurOutputIndexType, &rl.OurOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogTheirOutputIndexType, &rl.TheirOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogCommitTxHashType, &rl.CommitTxHash, + ), + } + // Create the tlv stream. - tlvStream, err := rl.toTlvStream() + tlvStream, err := tlv.NewStream(records...) if err != nil { return err } @@ -351,13 +351,20 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { var rl RevocationLog // Create the tlv stream. - tlvStream, err := rl.toTlvStream() - if err != nil { - return rl, err - } + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord( + revLogOurOutputIndexType, &rl.OurOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogTheirOutputIndexType, &rl.TheirOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogCommitTxHashType, &rl.CommitTxHash, + ), + ) // Read the tlv stream. - if err := readTlvStream(r, tlvStream); err != nil { + if _, err := readTlvStream(r, tlvStream); err != nil { return rl, err } @@ -382,7 +389,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { } // Read the HTLC entry. - if err := readTlvStream(r, tlvStream); err != nil { + if _, err := readTlvStream(r, tlvStream); err != nil { // We've reached the end when hitting an EOF. if err == io.ErrUnexpectedEOF { break @@ -427,7 +434,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { // readTlvStream is a helper function that decodes the tlv stream from the // reader. -func readTlvStream(r io.Reader, s *tlv.Stream) error { +func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) { var bodyLen uint64 // Read the stream's length. @@ -436,16 +443,17 @@ func readTlvStream(r io.Reader, s *tlv.Stream) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this results in an // invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // TODO(yy): add overflow check. lr := io.LimitReader(r, int64(bodyLen)) - return s.Decode(lr) + + return s.DecodeWithParsedTypes(lr) } // fetchLogBucket returns a read bucket by visiting both the new and the old diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index e34c7e70b..30405ba8a 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -12,8 +12,18 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -// OutputIndexEmpty is used when the output index doesn't exist. -const OutputIndexEmpty = math.MaxUint16 +const ( + // OutputIndexEmpty is used when the output index doesn't exist. + OutputIndexEmpty = math.MaxUint16 + + // A set of tlv type definitions used to serialize the body of + // revocation logs to the database. + // + // NOTE: A migration should be added whenever this list changes. + revLogOurOutputIndexType tlv.Type = 0 + revLogTheirOutputIndexType tlv.Type = 1 + revLogCommitTxHashType tlv.Type = 2 +) var ( // revocationLogBucketDeprecated is dedicated for storing the necessary @@ -196,29 +206,6 @@ type RevocationLog struct { HTLCEntries []*HTLCEntry } -// toTlvStream converts an RevocationLog record into a tlv representation. -func (rl *RevocationLog) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize the body of - // revocation logs to the database. We define it here instead - // of the head of the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - ourOutputIndexType tlv.Type = 0 - theirOutputIndexType tlv.Type = 1 - commitTxHashType tlv.Type = 2 - ) - - return tlv.NewStream( - tlv.MakePrimitiveRecord(ourOutputIndexType, &rl.OurOutputIndex), - tlv.MakePrimitiveRecord( - theirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord(commitTxHashType, &rl.CommitTxHash), - ) -} - // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a // ChannelCommitment to construct a revocation log entry and saves them to // disk. It also saves our output index and their output index, which are @@ -292,8 +279,21 @@ func fetchRevocationLog(log kvdb.RBucket, // serializeRevocationLog serializes a RevocationLog record based on tlv // format. func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { + // Add the tlv records for all non-optional fields. + records := []tlv.Record{ + tlv.MakePrimitiveRecord( + revLogOurOutputIndexType, &rl.OurOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogTheirOutputIndexType, &rl.TheirOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogCommitTxHashType, &rl.CommitTxHash, + ), + } + // Create the tlv stream. - tlvStream, err := rl.toTlvStream() + tlvStream, err := tlv.NewStream(records...) if err != nil { return err } @@ -339,13 +339,20 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { var rl RevocationLog // Create the tlv stream. - tlvStream, err := rl.toTlvStream() - if err != nil { - return rl, err - } + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord( + revLogOurOutputIndexType, &rl.OurOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogTheirOutputIndexType, &rl.TheirOutputIndex, + ), + tlv.MakePrimitiveRecord( + revLogCommitTxHashType, &rl.CommitTxHash, + ), + ) // Read the tlv stream. - if err := readTlvStream(r, tlvStream); err != nil { + if _, err := readTlvStream(r, tlvStream); err != nil { return rl, err } @@ -370,7 +377,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { } // Read the HTLC entry. - if err := readTlvStream(r, tlvStream); err != nil { + if _, err := readTlvStream(r, tlvStream); err != nil { // We've reached the end when hitting an EOF. if err == io.ErrUnexpectedEOF { break @@ -415,7 +422,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { // readTlvStream is a helper function that decodes the tlv stream from the // reader. -func readTlvStream(r io.Reader, s *tlv.Stream) error { +func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) { var bodyLen uint64 // Read the stream's length. @@ -424,16 +431,17 @@ func readTlvStream(r io.Reader, s *tlv.Stream) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this results in an // invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // TODO(yy): add overflow check. lr := io.LimitReader(r, int64(bodyLen)) - return s.Decode(lr) + + return s.DecodeWithParsedTypes(lr) } // fetchOldRevocationLog finds the revocation log from the deprecated diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 405955f46..6ca5b33b0 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -127,7 +127,7 @@ func TestReadTLVStream(t *testing.T) { // Read the tlv stream. buf := bytes.NewBuffer(testValueBytes) - err = readTlvStream(buf, ts) + _, err = readTlvStream(buf, ts) require.NoError(t, err) // Check the bytes are read as expected. @@ -150,7 +150,7 @@ func TestReadTLVStreamErr(t *testing.T) { // Read the tlv stream. buf := bytes.NewBuffer(b) - err = readTlvStream(buf, ts) + _, err = readTlvStream(buf, ts) require.ErrorIs(t, err, io.ErrUnexpectedEOF) // Check the bytes are not read.