dag: Walk

This commit is contained in:
Mitchell Hashimoto 2015-02-04 10:29:03 -05:00
parent cfa3d89265
commit e94c43e0dc
2 changed files with 79 additions and 1 deletions

View File

@ -2,6 +2,7 @@ package dag
import ( import (
"fmt" "fmt"
"sync"
) )
// AcyclicGraph is a specialization of Graph that cannot have cycles. With // AcyclicGraph is a specialization of Graph that cannot have cycles. With
@ -37,5 +38,48 @@ func (g *AcyclicGraph) Root() (Vertex, error) {
} }
// Walk walks the graph, calling your callback as each node is visited. // Walk walks the graph, calling your callback as each node is visited.
func (g *AcyclicGraph) Walk(cb WalkFunc) { // This will walk nodes in parallel if it can.
func (g *AcyclicGraph) Walk(cb WalkFunc) error {
// We require a root to walk.
root, err := g.Root()
if err != nil {
return err
}
// Build the waitgroup that signals when we're done
var wg sync.WaitGroup
wg.Add(g.vertices.Len())
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
wg.Wait()
}()
// Start walking!
visitCh := make(chan Vertex, g.vertices.Len())
visitCh <- root
for {
select {
case v := <-visitCh:
go g.walkVertex(v, cb, visitCh, &wg)
case <-doneCh:
goto WALKDONE
}
}
WALKDONE:
return nil
}
func (g *AcyclicGraph) walkVertex(
v Vertex, cb WalkFunc, nextCh chan<- Vertex, wg *sync.WaitGroup) {
defer wg.Done()
// Call the callback on this vertex
cb(v)
// Walk all the children in parallel
for _, v := range g.DownEdges(v).List() {
nextCh <- v.(Vertex)
}
} }

View File

@ -1,6 +1,8 @@
package dag package dag
import ( import (
"reflect"
"sync"
"testing" "testing"
) )
@ -44,3 +46,35 @@ func TestAcyclicGraphRoot_multiple(t *testing.T) {
t.Fatal("should error") t.Fatal("should error")
} }
} }
func TestAcyclicGraphWalk(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
g.Connect(BasicEdge(3, 1))
var visits []Vertex
var lock sync.Mutex
err := g.Walk(func(v Vertex) {
lock.Lock()
defer lock.Unlock()
visits = append(visits, v)
})
if err != nil {
t.Fatalf("err: %s", err)
}
expected := [][]Vertex{
{3, 1, 2},
{3, 2, 1},
}
for _, e := range expected {
if reflect.DeepEqual(visits, e) {
return
}
}
t.Fatalf("bad: %#v", visits)
}