diff --git a/lnwire/features.go b/lnwire/features.go index 70aa23154..98202841f 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -214,19 +214,51 @@ var Features = map[FeatureBit]string{ // can be serialized and deserialized to/from a byte representation that is // transmitted in Lightning network messages. type RawFeatureVector struct { - features map[FeatureBit]bool + features map[FeatureBit]struct{} } // NewRawFeatureVector creates a feature vector with all of the feature bits // given as arguments enabled. func NewRawFeatureVector(bits ...FeatureBit) *RawFeatureVector { - fv := &RawFeatureVector{features: make(map[FeatureBit]bool)} + fv := &RawFeatureVector{features: make(map[FeatureBit]struct{})} for _, bit := range bits { fv.Set(bit) } return fv } +// IsEmpty returns whether the feature vector contains any feature bits. +func (fv RawFeatureVector) IsEmpty() bool { + return len(fv.features) == 0 +} + +// OnlyContains determines whether only the specified feature bits are found. +func (fv RawFeatureVector) OnlyContains(bits ...FeatureBit) bool { + if len(bits) != len(fv.features) { + return false + } + for _, bit := range bits { + if !fv.IsSet(bit) { + return false + } + } + return true +} + +// Equals determines whether two features vectors contain exactly the same +// features. +func (fv RawFeatureVector) Equals(other *RawFeatureVector) bool { + if len(fv.features) != len(other.features) { + return false + } + for bit := range fv.features { + if _, ok := other.features[bit]; !ok { + return false + } + } + return true +} + // Merges sets all feature bits in other on the receiver's feature vector. func (fv *RawFeatureVector) Merge(other *RawFeatureVector) error { for bit := range other.features { @@ -249,12 +281,13 @@ func (fv *RawFeatureVector) Clone() *RawFeatureVector { // IsSet returns whether a particular feature bit is enabled in the vector. func (fv *RawFeatureVector) IsSet(feature FeatureBit) bool { - return fv.features[feature] + _, ok := fv.features[feature] + return ok } // Set marks a feature as enabled in the vector. func (fv *RawFeatureVector) Set(feature FeatureBit) { - fv.features[feature] = true + fv.features[feature] = struct{}{} } // SafeSet sets the chosen feature bit in the feature vector, but returns an diff --git a/lnwire/features_test.go b/lnwire/features_test.go index 3eed2b1b4..d4be13d4c 100644 --- a/lnwire/features_test.go +++ b/lnwire/features_test.go @@ -353,3 +353,47 @@ func TestFeatures(t *testing.T) { }) } } + +func TestRawFeatureVectorOnlyContains(t *testing.T) { + t.Parallel() + + features := []FeatureBit{ + StaticRemoteKeyOptional, + AnchorsZeroFeeHtlcTxOptional, + ExplicitChannelTypeRequired, + } + fv := NewRawFeatureVector(features...) + require.True(t, fv.OnlyContains(features...)) + require.False(t, fv.OnlyContains(features[:1]...)) +} + +func TestEqualRawFeatureVectors(t *testing.T) { + t.Parallel() + + a := NewRawFeatureVector( + StaticRemoteKeyOptional, + AnchorsZeroFeeHtlcTxOptional, + ExplicitChannelTypeRequired, + ) + b := a.Clone() + require.True(t, a.Equals(b)) + + b.Unset(ExplicitChannelTypeRequired) + require.False(t, a.Equals(b)) + + b.Set(ExplicitChannelTypeOptional) + require.False(t, a.Equals(b)) +} + +func TestIsEmptyFeatureVector(t *testing.T) { + t.Parallel() + + fv := NewRawFeatureVector() + require.True(t, fv.IsEmpty()) + + fv.Set(StaticRemoteKeyOptional) + require.False(t, fv.IsEmpty()) + + fv.Unset(StaticRemoteKeyOptional) + require.True(t, fv.IsEmpty()) +} diff --git a/lnwire/writer_test.go b/lnwire/writer_test.go index 96ef66103..3d16b1336 100644 --- a/lnwire/writer_test.go +++ b/lnwire/writer_test.go @@ -193,9 +193,11 @@ func TestWriteRawFeatureVector(t *testing.T) { require.Equal(t, ErrNilFeatureVector, err) // Create a raw feature vector. - feature := &RawFeatureVector{features: map[FeatureBit]bool{ - InitialRoutingSync: true, // FeatureBit 3. - }} + feature := &RawFeatureVector{ + features: map[FeatureBit]struct{}{ + InitialRoutingSync: {}, // FeatureBit 3. + }, + } expectedBytes := []byte{ 0, 1, // First two bytes encode the length. 8, // Last byte encodes the feature bit (1 << 3).