diff --git a/lnwire/features.go b/lnwire/features.go index 76dab3b05..21b7fbc02 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -3,6 +3,7 @@ package lnwire import ( "encoding/binary" "errors" + "fmt" "io" ) @@ -10,6 +11,11 @@ var ( // ErrFeaturePairExists signals an error in feature vector construction // where the opposing bit in a feature pair has already been set. ErrFeaturePairExists = errors.New("feature pair exists") + + // ErrFeatureStandard is returned when attempts to modify LND's known + // set of features are made. + ErrFeatureStandard = errors.New("feature is used in standard " + + "protocol set") ) // FeatureBit represents a feature that can be enabled in either a local or @@ -344,6 +350,57 @@ func (fv *RawFeatureVector) Merge(other *RawFeatureVector) error { return nil } +// ValidateUpdate checks whether a feature vector can safely be updated to the +// new feature vector provided, checking that it does not alter any of the +// "standard" features that are defined by LND. The new feature vector should +// be inclusive of all features in the original vector that it still wants to +// advertise, setting and unsetting updates as desired. +func (fv *RawFeatureVector) ValidateUpdate(other *RawFeatureVector) error { + // Run through the new set of features and check that we're not adding + // any feature bits that are defined but not set in LND. + for feature := range other.features { + if fv.IsSet(feature) { + continue + } + + if name, known := Features[feature]; known { + return fmt.Errorf("can't set feature "+ + "bit %d (%v): %w", feature, name, + ErrFeatureStandard) + } + } + + // Check that the new feature vector for this set does not unset any + // features that are standard in LND by comparing the features in our + // current set to the omitted values in the new set. + for feature := range fv.features { + if other.IsSet(feature) { + continue + } + + if name, known := Features[feature]; known { + return fmt.Errorf("can't unset feature "+ + "bit %d (%v): %w", feature, name, + ErrFeatureStandard) + } + } + + return nil +} + +// ValidatePairs checks each feature bit in a raw vector to ensure that the +// opposing bit is not set, validating that the vector has either the optional +// or required bit set, not both. +func (fv *RawFeatureVector) ValidatePairs() error { + for feature := range fv.features { + if _, ok := fv.features[feature^1]; ok { + return ErrFeaturePairExists + } + } + + return nil +} + // Clone makes a copy of a feature vector. func (fv *RawFeatureVector) Clone() *RawFeatureVector { newFeatures := NewRawFeatureVector() diff --git a/lnwire/features_test.go b/lnwire/features_test.go index a945eeb3e..84053b963 100644 --- a/lnwire/features_test.go +++ b/lnwire/features_test.go @@ -396,3 +396,103 @@ func TestIsEmptyFeatureVector(t *testing.T) { fv.Unset(StaticRemoteKeyOptional) require.True(t, fv.IsEmpty()) } + +// TestValidatePairs tests that feature vectors can only set the required or +// optional feature bit in a pair, not both. +func TestValidatePairs(t *testing.T) { + t.Parallel() + + rfv := NewRawFeatureVector( + StaticRemoteKeyOptional, + StaticRemoteKeyRequired, + ) + require.Equal(t, ErrFeaturePairExists, rfv.ValidatePairs()) + + rfv = NewRawFeatureVector( + StaticRemoteKeyOptional, + PaymentAddrRequired, + ) + require.Nil(t, rfv.ValidatePairs()) +} + +// TestValidateUpdate tests validation of an update to a feature vector. +func TestValidateUpdate(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + currentFeatures []FeatureBit + newFeatures []FeatureBit + err error + }{ + { + name: "defined feature bit set, can include", + currentFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + }, + newFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + }, + err: nil, + }, + { + name: "defined feature bit not already set", + currentFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + }, + newFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + PaymentAddrRequired, + }, + err: ErrFeatureStandard, + }, + { + name: "known feature missing", + currentFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + PaymentAddrRequired, + }, + newFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + }, + err: ErrFeatureStandard, + }, + { + name: "can set unknown feature", + currentFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + }, + newFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + FeatureBit(1001), + }, + err: nil, + }, + { + name: "can unset unknown feature", + currentFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + FeatureBit(1001), + }, + newFeatures: []FeatureBit{ + StaticRemoteKeyOptional, + }, + err: nil, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + currentFV := NewRawFeatureVector( + testCase.currentFeatures..., + ) + newFV := NewRawFeatureVector(testCase.newFeatures...) + + err := currentFV.ValidateUpdate(newFV) + require.ErrorIs(t, err, testCase.err) + }) + } +}