diff --git a/feature/deps.go b/feature/deps.go index 64b4f2fc9..3371aa3a0 100644 --- a/feature/deps.go +++ b/feature/deps.go @@ -98,6 +98,58 @@ func ValidateDeps(fv *lnwire.FeatureVector) error { return validateDeps(features, supported) } +// SetBit sets the given feature bit on the given feature bit vector along with +// any of its dependencies. If the bit is required, then all the dependencies +// are also set to required, otherwise, the optional dependency bits are set. +// Existing bits are only upgraded from optional to required but never +// downgraded from required to optional. +func SetBit(vector *lnwire.FeatureVector, + bit lnwire.FeatureBit) *lnwire.FeatureVector { + + fv := vector.Clone() + + // Get the optional version of the bit since that is what the deps map + // uses. + optBit := mapToOptional(bit) + + // If the bit we are setting is optional, then we set it (in its + // optional form) and also set all its dependents as optional if they + // are not already set (they may already be set in a required form in + // which case they should not be overridden). + if !bit.IsRequired() { + // Set the bit itself if it does not already exist. We use + // SafeSet here so that if the bit already exists in the + // required form, then this is not overwritten. + _ = fv.SafeSet(bit) + + // Do the same for all the dependent bits. + for depBit := range deps[optBit] { + fv = SetBit(fv, depBit) + } + + return fv + } + + // The bit is required. In this case, we do want to override any + // existing optional bit for both the bit itself and for the dependent + // bits. + fv.Unset(optBit) + fv.Set(bit) + + // Do the same for all the dependent bits. + for depBit := range deps[optBit] { + // The deps map only contains the optional versions of bits, so + // there is no need to first map the bit to the optional + // version. + fv.Unset(depBit) + + // Set the required version of the bit instead. + fv = SetBit(fv, mapToRequired(depBit)) + } + + return fv +} + // validateDeps is a subroutine that recursively checks that the passed features // have all of their associated dependencies in the supported map. func validateDeps(features featureSet, supported supportedFeatures) error { @@ -157,3 +209,13 @@ func mapToOptional(bit lnwire.FeatureBit) lnwire.FeatureBit { } return bit } + +// mapToRequired returns the required variant of a given feature bit pair. +func mapToRequired(bit lnwire.FeatureBit) lnwire.FeatureBit { + if bit.IsRequired() { + return bit + } + bit ^= 0x01 + + return bit +} diff --git a/feature/deps_test.go b/feature/deps_test.go index 4116e6ea9..9b6b02fa0 100644 --- a/feature/deps_test.go +++ b/feature/deps_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) type depTest struct { @@ -164,3 +165,170 @@ func testValidateDeps(t *testing.T, test depTest) { test.expErr, err) } } + +// TestSettingDepBits sets that the SetBit function correctly sets a bit along +// with its dependencies in a feature vector. Specifically, we want to check +// that any existing optional bits are upgraded to required if the main bit +// being set is required. Similarly, if the main bit is optional, then any +// existing bits that depend on it should not be downgraded from required to +// optional. +func TestSettingDepBits(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + existingVector *lnwire.RawFeatureVector + newBit lnwire.FeatureBit + expectedVector *lnwire.RawFeatureVector + }{ + { + name: "Optional bit with no dependants", + existingVector: lnwire.NewRawFeatureVector(), + newBit: lnwire.ExplicitChannelTypeOptional, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.ExplicitChannelTypeOptional, + ), + }, + { + name: "Required bit with no dependants", + existingVector: lnwire.NewRawFeatureVector(), + newBit: lnwire.ExplicitChannelTypeRequired, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.ExplicitChannelTypeRequired, + ), + }, + { + name: "Optional bit with single " + + "level dependant", + existingVector: lnwire.NewRawFeatureVector(), + newBit: lnwire.RouteBlindingOptional, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.RouteBlindingOptional, + lnwire.TLVOnionPayloadOptional, + ), + }, + { + name: "Required bit with single " + + "level dependant", + existingVector: lnwire.NewRawFeatureVector(), + newBit: lnwire.RouteBlindingRequired, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.RouteBlindingRequired, + lnwire.TLVOnionPayloadRequired, + ), + }, + { + name: "Optional bit with multi level " + + "dependants", + existingVector: lnwire.NewRawFeatureVector(), + newBit: lnwire.Bolt11BlindedPathsOptional, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.Bolt11BlindedPathsOptional, + lnwire.RouteBlindingOptional, + lnwire.TLVOnionPayloadOptional, + ), + }, + { + name: "Required bit with multi level " + + "dependants", + existingVector: lnwire.NewRawFeatureVector(), + newBit: lnwire.Bolt11BlindedPathsRequired, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.Bolt11BlindedPathsRequired, + lnwire.RouteBlindingRequired, + lnwire.TLVOnionPayloadRequired, + ), + }, + { + name: "Existing required bit should not be " + + "overridden if new bit is optional", + existingVector: lnwire.NewRawFeatureVector( + lnwire.TLVOnionPayloadRequired, + ), + newBit: lnwire.Bolt11BlindedPathsOptional, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.Bolt11BlindedPathsOptional, + lnwire.RouteBlindingOptional, + lnwire.TLVOnionPayloadRequired, + ), + }, + { + name: "Existing optional bit should be overridden if " + + "new bit is required", + existingVector: lnwire.NewRawFeatureVector( + lnwire.TLVOnionPayloadOptional, + ), + newBit: lnwire.Bolt11BlindedPathsRequired, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.Bolt11BlindedPathsRequired, + lnwire.RouteBlindingRequired, + lnwire.TLVOnionPayloadRequired, + ), + }, + { + name: "Unrelated bits should not be affected", + existingVector: lnwire.NewRawFeatureVector( + lnwire.AMPOptional, + lnwire.TLVOnionPayloadOptional, + ), + newBit: lnwire.Bolt11BlindedPathsRequired, + expectedVector: lnwire.NewRawFeatureVector( + lnwire.AMPOptional, + lnwire.Bolt11BlindedPathsRequired, + lnwire.RouteBlindingRequired, + lnwire.TLVOnionPayloadRequired, + ), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fv := lnwire.NewFeatureVector( + test.existingVector, lnwire.Features, + ) + + resultFV := SetBit(fv, test.newBit) + require.Equal( + t, test.expectedVector, + resultFV.RawFeatureVector, + ) + }) + } +} + +// TestSetBitNoCycles tests the SetBit call for each feature bit that we know of +// in both its optional and required form. This ensures that the SetBit call +// never gets stuck in a recursion cycle for any feature bit. +func TestSetBitNoCycles(t *testing.T) { + t.Parallel() + + // For each feature-bit that we are aware of (both optional and + // required), we will create a feature vector that is empty, and then + // we will call SetBit with the given feature bit. We then check that + // all the dependent features are also added in the appropriate form + // (optional vs required). This test completing demonstrates that the + // recursion in SetBit is not a problem since no feature bits should + // create a dependency cycle. + for bit := range lnwire.Features { + fv := lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), lnwire.Features, + ) + + resultFV := SetBit(fv, bit) + + // Ensure that all the dependent feature bits are in fact set + // in the resulting set. Here we just check that some form + // (optional or required) is set. The expected type is asserted + // later on in the test. + for expectedBit := range deps[bit] { + require.True(t, resultFV.IsSet(expectedBit) || + resultFV.IsSet(mapToRequired(expectedBit))) + } + + // Make sure all the resulting feature bits have the correct + // form (optional vs required). + for depBit := range resultFV.Features() { + require.Equal(t, bit.IsRequired(), depBit.IsRequired()) + } + } +}