diff --git a/fn/list.go b/fn/list.go index d76df4f29..41f1bdd84 100644 --- a/fn/list.go +++ b/fn/list.go @@ -300,3 +300,18 @@ func (l *List[A]) PushFrontList(other *List[A]) { n = n.Prev() } } + +// Filter gives a slice of all of the node values that satisfy the given +// predicate. +func (l *List[A]) Filter(f Pred[A]) []A { + var acc []A + + for cursor := l.Front(); cursor != nil; cursor = cursor.Next() { + a := cursor.Value + if f(a) { + acc = append(acc, a) + } + } + + return acc +} diff --git a/fn/list_test.go b/fn/list_test.go index c08122cc4..efe2c5a92 100644 --- a/fn/list_test.go +++ b/fn/list_test.go @@ -5,6 +5,9 @@ import ( "reflect" "testing" "testing/quick" + + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" ) func GenList(r *rand.Rand) *List[uint32] { @@ -727,3 +730,93 @@ func TestMoveUnknownMark(t *testing.T) { checkList(t, &l1, []int{1}) checkList(t, &l2, []int{2}) } + +// TestFilterIdempotence ensures that the slice coming out of List.Filter is +// the same as that slice filtered again by the same predicate. +func TestFilterIdempotence(t *testing.T) { + require.NoError( + t, quick.Check( + func(l *List[uint32], modSize uint32) bool { + pred := func(a uint32) bool { + return a%modSize != 0 + } + + filtered := l.Filter(pred) + + filteredAgain := Filter(pred, filtered) + + return slices.Equal(filtered, filteredAgain) + }, + &quick.Config{ + Values: func(vs []reflect.Value, r *rand.Rand) { + l := GenList(r) + vs[0] = reflect.ValueOf(l) + vs[1] = reflect.ValueOf( + r.Uint32()%5 + 1, + ) + }, + }, + ), + ) +} + +// TestFilterShrinks ensures that the length of the slice returned from +// List.Filter is never larger than the length of the List. +func TestFilterShrinks(t *testing.T) { + require.NoError( + t, quick.Check( + func(l *List[uint32], modSize uint32) bool { + pred := func(a uint32) bool { + return a%modSize != 0 + } + + filteredSize := len(l.Filter(pred)) + + return filteredSize <= l.Len() + }, + &quick.Config{ + Values: func(vs []reflect.Value, r *rand.Rand) { + l := GenList(r) + vs[0] = reflect.ValueOf(l) + vs[1] = reflect.ValueOf( + r.Uint32()%5 + 1, + ) + }, + }, + ), + ) +} + +// TestFilterLawOfExcludedMiddle ensures that if we intersect a List.Filter +// with its negation that the intersection is the empty set. +func TestFilterLawOfExcludedMiddle(t *testing.T) { + require.NoError( + t, quick.Check( + func(l *List[uint32], modSize uint32) bool { + pred := func(a uint32) bool { + return a%modSize != 0 + } + + negatedPred := func(a uint32) bool { + return !pred(a) + } + + positive := NewSet(l.Filter(pred)...) + negative := NewSet(l.Filter(negatedPred)...) + + return positive.Intersect(negative).Equal( + NewSet[uint32](), + ) + }, + &quick.Config{ + Values: func(vs []reflect.Value, r *rand.Rand) { + l := GenList(r) + vs[0] = reflect.ValueOf(l) + vs[1] = reflect.ValueOf( + r.Uint32()%5 + 1, + ) + }, + }, + ), + ) +}