diff --git a/dag/dag.go b/dag/dag.go index c53ec284a..c41255d1f 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -2,6 +2,7 @@ package dag import ( "fmt" + "sync" ) // AcyclicGraph is a specialization of Graph that cannot have cycles. With @@ -37,5 +38,48 @@ func (g *AcyclicGraph) Root() (Vertex, error) { } // Walk walks the graph, calling your callback as each node is visited. -func (g *AcyclicGraph) Walk(cb WalkFunc) { +// This will walk nodes in parallel if it can. +func (g *AcyclicGraph) Walk(cb WalkFunc) error { + // We require a root to walk. + root, err := g.Root() + if err != nil { + return err + } + + // Build the waitgroup that signals when we're done + var wg sync.WaitGroup + wg.Add(g.vertices.Len()) + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + wg.Wait() + }() + + // Start walking! + visitCh := make(chan Vertex, g.vertices.Len()) + visitCh <- root + for { + select { + case v := <-visitCh: + go g.walkVertex(v, cb, visitCh, &wg) + case <-doneCh: + goto WALKDONE + } + } + +WALKDONE: + return nil +} + +func (g *AcyclicGraph) walkVertex( + v Vertex, cb WalkFunc, nextCh chan<- Vertex, wg *sync.WaitGroup) { + defer wg.Done() + + // Call the callback on this vertex + cb(v) + + // Walk all the children in parallel + for _, v := range g.DownEdges(v).List() { + nextCh <- v.(Vertex) + } } diff --git a/dag/dag_test.go b/dag/dag_test.go index 94f25d6e9..2bbea81e4 100644 --- a/dag/dag_test.go +++ b/dag/dag_test.go @@ -1,6 +1,8 @@ package dag import ( + "reflect" + "sync" "testing" ) @@ -44,3 +46,35 @@ func TestAcyclicGraphRoot_multiple(t *testing.T) { t.Fatal("should error") } } + +func TestAcyclicGraphWalk(t *testing.T) { + var g AcyclicGraph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(3, 2)) + g.Connect(BasicEdge(3, 1)) + + var visits []Vertex + var lock sync.Mutex + err := g.Walk(func(v Vertex) { + lock.Lock() + defer lock.Unlock() + visits = append(visits, v) + }) + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := [][]Vertex{ + {3, 1, 2}, + {3, 2, 1}, + } + for _, e := range expected { + if reflect.DeepEqual(visits, e) { + return + } + } + + t.Fatalf("bad: %#v", visits) +}