dag: method for filtering a set on arbitrary criteria

This commit is contained in:
Martin Atkins 2017-05-10 17:55:11 -07:00
parent 510733ffd3
commit b28fc1cd20
2 changed files with 56 additions and 0 deletions

View File

@ -81,6 +81,20 @@ func (s *Set) Difference(other *Set) *Set {
return result
}
// Filter returns a set that contains the elements from the receiver
// where the given callback returns true.
func (s *Set) Filter(cb func(interface{}) bool) *Set {
result := new(Set)
for _, v := range s.m {
if cb(v) {
result.Add(v)
}
}
return result
}
// Len is the number of items in the set.
func (s *Set) Len() int {
if s == nil {

View File

@ -54,3 +54,45 @@ func TestSetDifference(t *testing.T) {
})
}
}
func TestSetFilter(t *testing.T) {
cases := []struct {
Input []interface{}
Expected []interface{}
}{
{
[]interface{}{1, 2, 3},
[]interface{}{1, 2, 3},
},
{
[]interface{}{4, 5, 6},
[]interface{}{4},
},
{
[]interface{}{7, 8, 9},
[]interface{}{},
},
}
for i, tc := range cases {
t.Run(fmt.Sprintf("%d-%#v", i, tc.Input), func(t *testing.T) {
var input, expected Set
for _, v := range tc.Input {
input.Add(v)
}
for _, v := range tc.Expected {
expected.Add(v)
}
actual := input.Filter(func(v interface{}) bool {
return v.(int) < 5
})
match := actual.Intersection(&expected)
if match.Len() != expected.Len() {
t.Fatalf("bad: %#v", actual.List())
}
})
}
}