diff --git a/dag/set.go b/dag/set.go index fc16e801b..2e5d85b2d 100644 --- a/dag/set.go +++ b/dag/set.go @@ -38,17 +38,18 @@ func (s Set) Include(v interface{}) bool { // Intersection computes the set intersection with other. func (s Set) Intersection(other Set) Set { result := make(Set) - if s == nil { + if s == nil || other == nil { return result } - if other != nil { - for _, v := range s { - if other.Include(v) { - result.Add(v) - } + // Iteration over a smaller set has better performance. + if other.Len() < s.Len() { + s, other = other, s + } + for _, v := range s { + if other.Include(v) { + result.Add(v) } } - return result } diff --git a/dag/set_test.go b/dag/set_test.go index 36bd6a65b..d59eacfbd 100644 --- a/dag/set_test.go +++ b/dag/set_test.go @@ -119,3 +119,31 @@ func TestSetCopy(t *testing.T) { } } + +func makeSet(n int) Set { + ret := make(Set, n) + for i := 0; i < n; i++ { + ret.Add(i) + } + return ret +} + +func BenchmarkSetIntersection_100_100000(b *testing.B) { + small := makeSet(100) + large := makeSet(100000) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + small.Intersection(large) + } +} + +func BenchmarkSetIntersection_100000_100(b *testing.B) { + small := makeSet(100) + large := makeSet(100000) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + large.Intersection(small) + } +}